Skip to content

Commit

Permalink
Compare records by properties when requested (nunit#4837)
Browse files Browse the repository at this point in the history
* Compare records by properties when requested

* review fixes
  • Loading branch information
Dreamescaper authored Sep 25, 2024
1 parent af7b941 commit b10f221
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/NUnitFramework/Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project>

<PropertyGroup>
<LangVersion Condition="'$(MSBuildProjectExtension)' == '.csproj'">11</LangVersion>
<LangVersion Condition="'$(MSBuildProjectExtension)' == '.csproj'">12</LangVersion>
<Features>strict</Features>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>..\..\nunit.snk</AssemblyOriginatorKeyFile>
Expand Down
2 changes: 1 addition & 1 deletion src/NUnitFramework/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
<PackageVersion Include="DotNetAnalyzers.DocumentationAnalyzers" Version="1.0.0-beta.59" />
<PackageVersion Include="Microsoft.CodeAnalysis.CSharp.CodeStyle" Version="4.10.0" />
<PackageVersion Include="NUnit.Analyzers" Version="4.3.0" />
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.507" />
<PackageVersion Include="StyleCop.Analyzers" Version="1.2.0-beta.556" />
</ItemGroup>
<!-- Specific dependencies -->
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Reflection;
using NUnit.Framework.Internal;

namespace NUnit.Framework.Constraints.Comparers
{
Expand All @@ -17,6 +18,12 @@ public static EqualMethodResult Equal(object x, object y, ref Tolerance toleranc

Type xType = x.GetType();

if (equalityComparer.CompareProperties && TypeHelper.HasCompilerGeneratedEquals(xType))
{
// For record types, when CompareProperties is requested, we ignore generated Equals method and compare by properties.
return EqualMethodResult.TypesNotSupported;
}

if (OverridesEqualsObject(xType))
{
if (tolerance.HasVariance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Reflection;
using NUnit.Framework.Internal;

namespace NUnit.Framework.Constraints.Comparers
{
Expand All @@ -18,6 +19,12 @@ public static EqualMethodResult Equal(object x, object y, ref Tolerance toleranc
Type xType = x.GetType();
Type yType = y.GetType();

if (equalityComparer.CompareProperties && TypeHelper.HasCompilerGeneratedEquals(xType))
{
// For record types, when CompareProperties is requested, we ignore generated Equals method and compare by properties.
return EqualMethodResult.TypesNotSupported;
}

MethodInfo? equals = FirstImplementsIEquatableOfSecond(xType, yType);
if (equals is not null)
{
Expand Down
9 changes: 9 additions & 0 deletions src/NUnitFramework/framework/Internal/TypeHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using NUnit.Framework.Interfaces;

Expand Down Expand Up @@ -387,5 +388,13 @@ internal static string FullName(this Type type)
{
return type.FullName ?? throw new InvalidOperationException("No name for type: " + type);
}

internal static bool HasCompilerGeneratedEquals(this Type type)
{
var equalsMethod = type.GetMethod(nameof(type.Equals), BindingFlags.Instance | BindingFlags.Public,
null, [type], null);

return equalsMethod?.GetCustomAttribute<CompilerGeneratedAttribute>() is not null;
}
}
}
39 changes: 38 additions & 1 deletion src/NUnitFramework/tests/Assertions/AssertThatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
using System;
using System.Threading.Tasks;
using NUnit.Framework.Interfaces;
using NUnit.TestData;
using NUnit.Framework.Tests.TestUtilities;
using NUnit.TestData;

namespace NUnit.Framework.Tests.Assertions
{
Expand Down Expand Up @@ -599,6 +599,26 @@ public void AssertWithCyclicRecursiveClass()
Assert.That(list1, Is.EqualTo(list2).UsingPropertiesComparer());
}

[Test]
public void AssertRecordsComparingProperties()
{
var record1 = new Record("Name", [1, 2, 3]);
var record2 = new Record("Name", [1, 2, 3]);

Assert.That(record1, Is.Not.EqualTo(record2)); // Record's generated method does not handle collections
Assert.That(record1, Is.EqualTo(record2).UsingPropertiesComparer());
}

[Test]
public void AssertRecordsComparingProperties_WhenRecordHasUserDefinedEqualsMethod()
{
var record1 = new ParentRecord(new RecordWithOverriddenEquals("Name"), [1, 2, 3]);
var record2 = new ParentRecord(new RecordWithOverriddenEquals("NAME"), [1, 2, 3]);

Assert.That(record1, Is.Not.EqualTo(record2)); // ParentRecord's generated method does not handle collections
Assert.That(record1, Is.EqualTo(record2).UsingPropertiesComparer());
}

private sealed class LinkedList
{
public LinkedList(int value, LinkedList? next = null)
Expand Down Expand Up @@ -655,6 +675,23 @@ public void TestPropertyFailureSecondLevel()
*/
}

private record Record(string Name, int[] Collection);

private record ParentRecord(RecordWithOverriddenEquals Child, int[] Collection);

private record RecordWithOverriddenEquals(string Name)
{
public virtual bool Equals(RecordWithOverriddenEquals? other)
{
return string.Equals(Name, other?.Name, StringComparison.OrdinalIgnoreCase);
}

public override int GetHashCode()
{
return Name.ToUpperInvariant().GetHashCode();
}
}

private sealed class ParentClass
{
public ParentClass(ChildClass one, ChildClass two)
Expand Down
62 changes: 62 additions & 0 deletions src/NUnitFramework/tests/Internal/TypeHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,67 @@ private class Generic<T>
}

#endregion

#region HasCompilerGeneratedEquals

[TestCase(typeof(RecordClass), ExpectedResult = true)]
[TestCase(typeof(RecordStruct), ExpectedResult = true)]
[TestCase(typeof(RecordWithProperties), ExpectedResult = true)]
[TestCase(typeof(RecordWithOverriddenEquals), ExpectedResult = false)]
[TestCase(typeof(int), ExpectedResult = false)]
[TestCase(typeof(int[]), ExpectedResult = false)]
[TestCase(typeof(DtoClass), ExpectedResult = false)]
[TestCase(typeof(ClassWithPrimaryConstructor), ExpectedResult = false)]
[TestCase(typeof(ClassWithOverriddenEquals), ExpectedResult = false)]
public bool HasCompilerGeneratedEqualsTests(Type type) => TypeHelper.HasCompilerGeneratedEquals(type);

private class DtoClass
{
public string? Name { get; set; }
}

private class ClassWithPrimaryConstructor(string name)
{
public string Name => name;
}

private class ClassWithOverriddenEquals
{
public string? Name { get; set; }

public override bool Equals(object? obj)
{
return obj is ClassWithOverriddenEquals other && other.Name == Name;
}

public override int GetHashCode()
{
return 539060726 + EqualityComparer<string?>.Default.GetHashCode(Name ?? string.Empty);
}
}

private record class RecordClass(string Name);

private record struct RecordStruct(string Name);

private record RecordWithProperties
{
public string? Name { get; set; }
}

private record RecordWithOverriddenEquals(string Name)
{
public virtual bool Equals(RecordWithOverriddenEquals? other)
{
return string.Equals(Name, other?.Name, StringComparison.OrdinalIgnoreCase);
}

public override int GetHashCode()
{
return Name.ToUpperInvariant().GetHashCode();
}
}

#endregion
}
}

0 comments on commit b10f221

Please sign in to comment.