Skip to content

Commit

Permalink
Refactoring existing source generator to use new incremental api
Browse files Browse the repository at this point in the history
  • Loading branch information
jlevier committed Mar 28, 2023
1 parent 7cd022c commit a01c3bf
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.SourceGenerators.Testing.XUnit" Version="1.1.1" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="3.9.0" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="4.5.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.2.0" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.5">
Expand Down
4 changes: 2 additions & 2 deletions OneOf.SourceGenerator/OneOf.SourceGenerator.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.3">
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.4">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="3.9.0">
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="4.5.0">
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
Expand Down
224 changes: 123 additions & 101 deletions OneOf.SourceGenerator/OneOfGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using System.Threading;

namespace OneOf.SourceGenerator
{
[Generator]
public class OneOfGenerator : ISourceGenerator
public class OneOfGenerator : IIncrementalGenerator
{
private const string AttributeName = "GenerateOneOfAttribute";
private const string AttributeNamespace = "OneOf";

private readonly string _attributeText = $@"// <auto-generated />
private static readonly string _attributeText = $@"// <auto-generated />
using System;
#pragma warning disable 1591
Expand All @@ -29,158 +30,179 @@ internal sealed class {AttributeName} : Attribute
}}
";

public void Execute(GeneratorExecutionContext context)
public void Initialize(IncrementalGeneratorInitializationContext context)
{
if (context.SyntaxReceiver is not OneOfSyntaxReceiver receiver)
{
return;
}

Compilation compilation = context.Compilation;
context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
$"{AttributeName}.g.cs",
SourceText.From(_attributeText, Encoding.UTF8)));

IncrementalValuesProvider<ClassToGenerate?> classesToGenerate = context.SyntaxProvider
.ForAttributeWithMetadataName(
$"{AttributeNamespace}.{AttributeName}",
predicate: (node, _) => IsCandidateForGeneration(node),
transform: GetClassToGenerate)
.Where(static m => m is not null);

context.RegisterSourceOutput(classesToGenerate,
static (spc, classToGenerate) => Execute(in classToGenerate, spc));
}

INamedTypeSymbol? attributeSymbol =
compilation.GetTypeByMetadataName($"{AttributeNamespace}.{AttributeName}");
private static bool IsCandidateForGeneration(SyntaxNode node)
=> node is ClassDeclarationSyntax classSyntax && classSyntax.AttributeLists.Count > 0;

if (attributeSymbol is null)
private static void Execute(in ClassToGenerate? classToGenerate, SourceProductionContext context)
{
if (classToGenerate is { } validClass)
{
return;
if (validClass.Error is not null)
{
context.ReportDiagnostic(validClass.Error);
return;
}

var result = GetClassCode(in validClass);
context.AddSource($"{validClass.Namespace}_{validClass.Name}.g.cs", SourceText.From(result, Encoding.UTF8));
}
}

List<(INamedTypeSymbol, Location?)> namedTypeSymbols = new();
foreach (ClassDeclarationSyntax classDeclaration in receiver.CandidateClasses)
private static ClassToGenerate? GetClassToGenerate(GeneratorAttributeSyntaxContext context, CancellationToken ct)
{
if (context.TargetSymbol is not INamedTypeSymbol classSymbol)
{
SemanticModel model = compilation.GetSemanticModel(classDeclaration.SyntaxTree);
INamedTypeSymbol? namedTypeSymbol = model.GetDeclaredSymbol(classDeclaration);

AttributeData? attributeData = namedTypeSymbol?.GetAttributes().FirstOrDefault(ad =>
ad.AttributeClass?.Equals(attributeSymbol, SymbolEqualityComparer.Default) != false);

if (attributeData is not null)
{
namedTypeSymbols.Add((namedTypeSymbol!,
attributeData.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
}
return null;
}

foreach ((INamedTypeSymbol namedSymbol, Location? attributeLocation) in namedTypeSymbols)
// Check to see if the class has the attribute we're looking for, otherwise return null and do nothing
if (!classSymbol.GetAttributes().Any(a => a.AttributeClass?.Name == AttributeName
&& a.AttributeClass?.ContainingNamespace.Name == AttributeNamespace))
{
string? classSource = ProcessClass(namedSymbol, context, attributeLocation);
return null;
}

if (classSource is null)
{
continue;
}
ct.ThrowIfCancellationRequested();

context.AddSource($"{namedSymbol.ContainingNamespace}_{namedSymbol.Name}.g.cs", classSource);
if (ClassHasErrors(classSymbol, context, out ClassToGenerate? classWithError))
{
return classWithError;
}

IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs =
classSymbol.BaseType!.TypeParameters.Zip(classSymbol.BaseType.TypeArguments, (param, arg) => (param, arg));

return new ClassToGenerate(
classSymbol.Name,
classSymbol.ContainingNamespace.ToDisplayString(),
classSymbol.TypeArguments,
classSymbol.BaseType!.TypeArguments,
paramArgPairs);
}

private static string? ProcessClass(INamedTypeSymbol classSymbol, GeneratorExecutionContext context, Location? attributeLocation)
private static bool ClassHasErrors(INamedTypeSymbol classSymbol, GeneratorAttributeSyntaxContext context, out ClassToGenerate? classWithError)
{
attributeLocation ??= Location.None;
classWithError = null;

if (!classSymbol.ContainingSymbol.Equals(classSymbol.ContainingNamespace, SymbolEqualityComparer.Default))
{
CreateDiagnosticError(GeneratorDiagnosticDescriptors.TopLevelError);
return null;
classWithError = CreateError(GeneratorDiagnosticDescriptors.TopLevelError, context.TargetNode.GetLocation(), classSymbol.Name);
return true;
}

if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != "OneOf")
if (classSymbol.BaseType is null || classSymbol.BaseType.Name != "OneOfBase" || classSymbol.BaseType.ContainingNamespace.ToString() != AttributeNamespace)
{
CreateDiagnosticError(GeneratorDiagnosticDescriptors.WrongBaseType);
return null;
classWithError = CreateError(GeneratorDiagnosticDescriptors.WrongBaseType, context.TargetNode.GetLocation(), classSymbol.Name);
return true;
}

ImmutableArray<ITypeSymbol> typeArguments = classSymbol.BaseType.TypeArguments;

foreach (ITypeSymbol typeSymbol in typeArguments)
foreach (ITypeSymbol typeSymbol in classSymbol.BaseType!.TypeArguments)
{
if (typeSymbol.Name == nameof(Object))
{
CreateDiagnosticError(GeneratorDiagnosticDescriptors.ObjectIsOneOfType);
return null;
classWithError = CreateError(GeneratorDiagnosticDescriptors.ObjectIsOneOfType, context.TargetNode.GetLocation(), classSymbol.Name);
return true;
}

if (typeSymbol.TypeKind == TypeKind.Interface)
{
CreateDiagnosticError(GeneratorDiagnosticDescriptors.UserDefinedConversionsToOrFromAnInterfaceAreNotAllowed);
return null;
classWithError = CreateError(GeneratorDiagnosticDescriptors.UserDefinedConversionsToOrFromAnInterfaceAreNotAllowed,
context.TargetNode.GetLocation(), classSymbol.Name);
return true;
}
}

return GenerateClassSource(classSymbol, classSymbol.BaseType.TypeParameters, typeArguments);

void CreateDiagnosticError(DiagnosticDescriptor descriptor)
=> context.ReportDiagnostic(Diagnostic.Create(descriptor, attributeLocation, classSymbol.Name, DiagnosticSeverity.Error));
return false;
}

private static string GenerateClassSource(INamedTypeSymbol classSymbol,
ImmutableArray<ITypeParameterSymbol> typeParameters, ImmutableArray<ITypeSymbol> typeArguments)
{
IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs =
typeParameters.Zip(typeArguments, (param, arg) => (param, arg));
private static ClassToGenerate CreateError(DiagnosticDescriptor descriptor, Location location, string name)
=> new(Diagnostic.Create(descriptor, location, name, DiagnosticSeverity.Error));

string oneOfGenericPart = GetGenericPart(typeArguments);
private static string GetGenericPart(ImmutableArray<ITypeSymbol> typeArguments) =>
string.Join(", ", typeArguments.Select(x => x.ToDisplayString()));

string classNameWithGenericTypes = $"{classSymbol.Name}{GetOpenGenericPart(classSymbol)}";
private static string? GetOpenGenericPart(ImmutableArray<ITypeSymbol> typeArguments)
{
if (!typeArguments.Any())
{
return null;
}

StringBuilder source = new($@"// <auto-generated />
#pragma warning disable 1591
return $"<{GetGenericPart(typeArguments)}>";
}

namespace {classSymbol.ContainingNamespace.ToDisplayString()}
{{
partial class {classNameWithGenericTypes}");
private static string GetClassCode(in ClassToGenerate classToGenerate)
{
string constructor = $"public {classToGenerate.Name}(OneOf.OneOf<{GetGenericPart(classToGenerate.OneOfBaseTypeArguments)}> _) : base(_) {{ }}";

source.Append($@"
{{
public {classSymbol.Name}(OneOf.OneOf<{oneOfGenericPart}> _) : base(_) {{ }}
");
string classNameWithGenericTypes = $"{classToGenerate.Name}{GetOpenGenericPart(classToGenerate.TypeArguments)}";

foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in paramArgPairs)
StringBuilder sbParamArgPairs = new();
foreach ((ITypeParameterSymbol param, ITypeSymbol arg) in classToGenerate.ParamArgPairs)
{
source.Append($@"
sbParamArgPairs.Append($@"
public static implicit operator {classNameWithGenericTypes}({arg.ToDisplayString()} _) => new {classNameWithGenericTypes}(_);
public static explicit operator {arg.ToDisplayString()}({classNameWithGenericTypes} _) => _.As{param.Name};
");
}

source.Append(@" }
}");
return source.ToString();
}
return $@"// <auto-generated />
#pragma warning disable 1591
private static string GetGenericPart(ImmutableArray<ITypeSymbol> typeArguments) =>
string.Join(", ", typeArguments.Select(x => x.ToDisplayString()));
namespace {classToGenerate.Namespace}
{{
partial class {classNameWithGenericTypes}
{{
{constructor}
{sbParamArgPairs}
}}
}}
";
}

private static string? GetOpenGenericPart(INamedTypeSymbol classSymbol)
internal sealed class ClassToGenerate
{
if (!classSymbol.TypeArguments.Any())
public ClassToGenerate(
string name,
string @namespace,
ImmutableArray<ITypeSymbol> typeArguments,
ImmutableArray<ITypeSymbol> oneOfBaseTypeArguments,
IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> paramArgPairs)
{
return null;
Name = name;
Namespace = @namespace;
TypeArguments = typeArguments;
OneOfBaseTypeArguments = oneOfBaseTypeArguments;
ParamArgPairs = paramArgPairs;
}

return $"<{GetGenericPart(classSymbol.TypeArguments)}>";
}
public ClassToGenerate(Diagnostic error)
: this("", "", default, default, new List<(ITypeParameterSymbol param, ITypeSymbol arg)>())
=> Error = error;

public void Initialize(GeneratorInitializationContext context)
{
context.RegisterForPostInitialization(ctx =>
ctx.AddSource($"{AttributeName}.g.cs", _attributeText));
context.RegisterForSyntaxNotifications(() => new OneOfSyntaxReceiver());
}

internal class OneOfSyntaxReceiver : ISyntaxReceiver
{
public List<ClassDeclarationSyntax> CandidateClasses { get; } = new();

public void OnVisitSyntaxNode(SyntaxNode syntaxNode)
{
if (syntaxNode is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclarationSyntax
&& classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword))
{
CandidateClasses.Add(classDeclarationSyntax);
}
}
public string Name { get; }
public string Namespace { get; }
public ImmutableArray<ITypeSymbol> TypeArguments { get; }
public ImmutableArray<ITypeSymbol> OneOfBaseTypeArguments { get; }
public IEnumerable<(ITypeParameterSymbol param, ITypeSymbol arg)> ParamArgPairs { get; }
public Diagnostic? Error { get; }
}
}
}
}

0 comments on commit a01c3bf

Please sign in to comment.