Skip to content

Commit

Permalink
support nested types, improve code generated
Browse files Browse the repository at this point in the history
  • Loading branch information
pwelter34 committed Sep 5, 2024
1 parent f2a9dc6 commit eb0ee18
Show file tree
Hide file tree
Showing 19 changed files with 342 additions and 193 deletions.
34 changes: 33 additions & 1 deletion src/Equatable.SourceGenerator/EquatableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
if (context.TargetSymbol is not INamedTypeSymbol targetSymbol)
return null;

var fullyQualified = targetSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var classNamespace = targetSymbol.ContainingNamespace.ToDisplayString();
var className = targetSymbol.Name;

// support nested types
var containingTypes = GetContainingTypes(targetSymbol);

var baseHashCode = GetBaseHashCodeMethod(targetSymbol);
var baseEquals = GetBaseEqualsMethod(targetSymbol);
var baseEquatable = GetBaseEquatableType(targetSymbol);
Expand All @@ -99,8 +103,10 @@ private static bool SyntacticPredicate(SyntaxNode syntaxNode, CancellationToken
seedHash = (seedHash * HashFactor) + GetFNVHashCode(property.PropertyName);

var entity = new EquatableClass(
FullyQualified: fullyQualified,
EntityNamespace: classNamespace,
EntityName: className,
ContainingTypes: containingTypes,
Properties: propertyArray,
IsRecord: targetSymbol.IsRecord,
IsValueType: targetSymbol.IsValueType,
Expand Down Expand Up @@ -144,7 +150,8 @@ private static IEnumerable<IPropertySymbol> GetProperties(INamedTypeSymbol targe

private static EquatableProperty CreateProperty(IPropertySymbol propertySymbol)
{
var propertyType = propertySymbol.Type.ToDisplayString();
var format = SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions(SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier);
var propertyType = propertySymbol.Type.ToDisplayString(format);
var propertyName = propertySymbol.Name;

// look for custom equality
Expand Down Expand Up @@ -286,6 +293,31 @@ private static bool IsValueType(INamedTypeSymbol targetSymbol)
};
}

private static EquatableArray<ContainingClass> GetContainingTypes(INamedTypeSymbol targetSymbol)
{
if (targetSymbol.ContainingType is null)
return Array.Empty<ContainingClass>();

var list = new List<ContainingClass>();
var currentSymbol = targetSymbol.ContainingType;

while (currentSymbol != null)
{
var containingClass = new ContainingClass(
EntityName: currentSymbol.Name,
IsRecord: currentSymbol.IsRecord,
IsValueType: currentSymbol.IsValueType
);

list.Add(containingClass);

currentSymbol = currentSymbol.ContainingType;
}

list.Reverse();

return list.ToArray();
}

private static IMethodSymbol? GetBaseHashCodeMethod(INamedTypeSymbol targetSymbol)
{
Expand Down
63 changes: 43 additions & 20 deletions src/Equatable.SourceGenerator/EquatableWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public static class EquatableWriter
{
public static string Generate(EquatableClass entityClass)
{
if (entityClass == null)
if (entityClass is null)
throw new ArgumentNullException(nameof(entityClass));

var codeBuilder = new IndentedStringBuilder();
Expand All @@ -23,6 +23,19 @@ public static string Generate(EquatableClass entityClass)
.AppendLine("{")
.IncrementIndent();

// support nested types
foreach (var containingClass in entityClass.ContainingTypes)
{
codeBuilder
.Append("partial ")
.AppendIf("record ", containingClass.IsRecord)
.AppendIf("class ", !containingClass.IsValueType)
.AppendIf("struct ", containingClass.IsValueType)
.AppendLine(containingClass.EntityName)
.AppendLine("{")
.IncrementIndent();
}

codeBuilder
.Append("partial ")
.AppendIf("record ", entityClass.IsRecord)
Expand All @@ -34,7 +47,7 @@ public static string Generate(EquatableClass entityClass)
{
codeBuilder
.Append(" : global::System.IEquatable<")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.Append(">");
}
Expand All @@ -50,7 +63,17 @@ public static string Generate(EquatableClass entityClass)

codeBuilder
.DecrementIndent()
.AppendLine("}") // class
.AppendLine("}"); // class

// support nested types
foreach (var containingClass in entityClass.ContainingTypes)
{
codeBuilder
.DecrementIndent()
.AppendLine("}");
}

codeBuilder
.DecrementIndent()
.AppendLine("}"); // namespace

Expand All @@ -69,7 +92,7 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab
.Append("public ")
.AppendIf("virtual ", entityClass.IsRecord && !entityClass.IsSealed)
.Append("bool Equals(")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.AppendLine(" other)")
.AppendLine("{")
Expand All @@ -83,7 +106,7 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab
}
else
{
codeBuilder.Append("return other is not null");
codeBuilder.Append("return !(other is null)");
wrote = true;
}

Expand Down Expand Up @@ -195,7 +218,7 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab

private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder, EquatableClass entityClass)
{
if (entityClass == null)
if (entityClass is null)
return;

if (entityClass.Properties.Any(p => p.ComparerType == ComparerTypes.Dictionary))
Expand All @@ -207,7 +230,7 @@ private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder
.AppendLine("if (global::System.Object.ReferenceEquals(left, right))")
.AppendLine(" return true;")
.AppendLine()
.AppendLine("if (left == null || right == null)")
.AppendLine("if (left is null || right is null)")
.AppendLine(" return false;")
.AppendLine()
.AppendLine("if (left.Count != right.Count)")
Expand Down Expand Up @@ -244,7 +267,7 @@ private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder
.AppendLine("if (global::System.Object.ReferenceEquals(left, right))")
.AppendLine(" return true;")
.AppendLine()
.AppendLine("if (left == null || right == null)")
.AppendLine("if (left is null || right is null)")
.AppendLine(" return false;")
.AppendLine()
.AppendLine("if (left is global::System.Collections.Generic.ISet<T> leftSet)")
Expand All @@ -269,7 +292,7 @@ private static void GenerateEquatableFunctions(IndentedStringBuilder codeBuilder
.AppendLine("if (global::System.Object.ReferenceEquals(left, right))")
.AppendLine(" return true;")
.AppendLine()
.AppendLine("if (left == null || right == null)")
.AppendLine("if (left is null || right is null)")
.AppendLine(" return false;")
.AppendLine()
.AppendLine("return global::System.Linq.Enumerable.SequenceEqual(left, right, global::System.Collections.Generic.EqualityComparer<T>.Default);")
Expand Down Expand Up @@ -300,14 +323,14 @@ private static void GenerateEquals(IndentedStringBuilder codeBuilder, EquatableC
{
codeBuilder
.Append("return obj is ")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendLine(" instance && Equals(instance);");
}
else
{
codeBuilder
.Append("return Equals(obj as ")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendLine(");");
}

Expand All @@ -324,16 +347,16 @@ private static void GenerateEquals(IndentedStringBuilder codeBuilder, EquatableC
.Append(ThisAssembly.InformationalVersion)
.AppendLine("\")]")
.Append("public static bool operator ==(")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.Append(" left, ")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.AppendLine(" right)")
.AppendLine("{")
.IncrementIndent()
.Append("return global::System.Collections.Generic.EqualityComparer<")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.AppendLine(">.Default.Equals(left, right);")
.DecrementIndent()
Expand All @@ -348,10 +371,10 @@ private static void GenerateEquals(IndentedStringBuilder codeBuilder, EquatableC
.Append(ThisAssembly.InformationalVersion)
.AppendLine("\")]")
.Append("public static bool operator !=(")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.Append(" left, ")
.Append(entityClass.EntityName)
.Append(entityClass.FullyQualified)
.AppendIf("?", !entityClass.IsValueType)
.AppendLine(" right)")
.AppendLine("{")
Expand Down Expand Up @@ -460,7 +483,7 @@ private static void GenerateHashCode(IndentedStringBuilder codeBuilder, Equatabl

private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder, EquatableClass entityClass)
{
if (entityClass == null)
if (entityClass is null)
return;

if (entityClass.Properties.Any(p => p.ComparerType == ComparerTypes.Dictionary))
Expand All @@ -469,7 +492,7 @@ private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder,
.AppendLine("static int DictionaryHashCode<TKey, TValue>(global::System.Collections.Generic.IDictionary<TKey, TValue>? items)")
.AppendLine("{")
.IncrementIndent()
.AppendLine("if (items == null)")
.AppendLine("if (items is null)")
.AppendLine(" return 0;")
.AppendLine();

Expand Down Expand Up @@ -501,7 +524,7 @@ private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder,
.AppendLine("static int HashSetHashCode<T>(global::System.Collections.Generic.IEnumerable<T>? items)")
.AppendLine("{")
.IncrementIndent()
.AppendLine("if (items == null)")
.AppendLine("if (items is null)")
.AppendLine(" return 0;")
.AppendLine()
.Append("int hashCode = ")
Expand All @@ -524,7 +547,7 @@ private static void GenerateHashCodeFunctions(IndentedStringBuilder codeBuilder,
.AppendLine("static int SequenceHashCode<T>(global::System.Collections.Generic.IEnumerable<T>? items)")
.AppendLine("{")
.IncrementIndent()
.AppendLine("if (items == null)")
.AppendLine("if (items is null)")
.AppendLine(" return 0;")
.AppendLine()
.Append("int hashCode = ")
Expand Down
7 changes: 7 additions & 0 deletions src/Equatable.SourceGenerator/Models/ContainingClass.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace Equatable.SourceGenerator.Models;

public record ContainingClass(
string EntityName,
bool IsRecord,
bool IsValueType
);
20 changes: 11 additions & 9 deletions src/Equatable.SourceGenerator/Models/EquatableClass.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
namespace Equatable.SourceGenerator.Models;

public record EquatableClass(
string EntityNamespace,
string EntityName,
EquatableArray<EquatableProperty> Properties,
bool IsRecord,
bool IsValueType,
bool IsSealed,
bool IncludeBaseEqualsMethod,
bool IncludeBaseHashMethod,
int SeedHash
string FullyQualified,
string EntityNamespace,
string EntityName,
EquatableArray<ContainingClass> ContainingTypes,
EquatableArray<EquatableProperty> Properties,
bool IsRecord,
bool IsValueType,
bool IsSealed,
bool IncludeBaseEqualsMethod,
bool IncludeBaseHashMethod,
int SeedHash
);
2 changes: 1 addition & 1 deletion test/Equatable.Entities/Nested.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace Equatable.Entities;

public partial class Nested
{
//[Equatable]
[Equatable]
public partial class Animal
{
public int Id { get; set; }
Expand Down
29 changes: 29 additions & 0 deletions test/Equatable.Generator.Tests/EquatableGeneratorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,35 @@ public partial class Audit
.ScrubLinesContaining("GeneratedCodeAttribute");
}

[Fact]
public Task GenerateNestedComparer()
{
var source = @"
using Equatable.Attributes;
namespace Equatable.Entities;
public partial class Nested
{
[Equatable]
public partial class Animal
{
public int Id { get; set; }
public string? Name { get; set; }
public string? Type { get; set; }
}
}
";

var (diagnostics, output) = GetGeneratedOutput<EquatableGenerator>(source);

diagnostics.Should().BeEmpty();

return Verifier
.Verify(output)
.UseDirectory("Snapshots")
.ScrubLinesContaining("GeneratedCodeAttribute");
}

private static (ImmutableArray<Diagnostic> Diagnostics, string Output) GetGeneratedOutput<T>(string source)
where T : IIncrementalGenerator, new()
Expand Down
6 changes: 6 additions & 0 deletions test/Equatable.Generator.Tests/EquatableWriterTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ public class EquatableWriterTest
public async Task GenerateBasicUser()
{
var entityClass = new EquatableClass(
FullyQualified: "global::Equatable.Entities.User",
EntityNamespace: "Equatable.Entities",
EntityName: "User",
ContainingTypes: Array.Empty<ContainingClass>(),
Properties: new EquatableArray<EquatableProperty>([
new EquatableProperty("Id", "int"),
new EquatableProperty("FirstName", "string?"),
Expand Down Expand Up @@ -38,8 +40,10 @@ public async Task GenerateBasicUser()
public async Task GenerateUserStringSequence()
{
var entityClass = new EquatableClass(
FullyQualified: "global::Equatable.Entities.User",
EntityNamespace: "Equatable.Entities",
EntityName: "User",
ContainingTypes: Array.Empty<ContainingClass>(),
Properties: new EquatableArray<EquatableProperty>([
new EquatableProperty("Id", "int"),
new EquatableProperty("FirstName", "string?"),
Expand Down Expand Up @@ -68,8 +72,10 @@ public async Task GenerateUserStringSequence()
public async Task GenerateUserImportHashSetDictionary()
{
var entityClass = new EquatableClass(
FullyQualified: "global::Equatable.Entities.UserImport",
EntityNamespace: "Equatable.Entities",
EntityName: "UserImport",
ContainingTypes: Array.Empty<ContainingClass>(),
Properties: new EquatableArray<EquatableProperty>([
new EquatableProperty("EmailAddress", "string", ComparerTypes.String, "OrdinalIgnoreCase"),
new EquatableProperty("DisplayName", "string?"),
Expand Down
Loading

0 comments on commit eb0ee18

Please sign in to comment.