Skip to content

Commit

Permalink
fix: dictionary analyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
ronimizy committed Oct 11, 2024
1 parent 26175dc commit 621083c
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using System.Collections.Generic;

namespace SourceKit.Analyzers.Collections.Samples.Dictionary;

public class DictionaryStringString
{
public Dictionary<string, string> Dict { get; set; } = new();
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ public class DictionaryKeyTypeMustImplementEquatableAnalyzer : DiagnosticAnalyze
DiagnosticSeverity.Error,
true);

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } =
ImmutableArray.Create(Descriptor);
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = [Descriptor];

public override void Initialize(AnalysisContext context)
{
Expand All @@ -37,46 +36,51 @@ private void AnalyzeGeneric(SyntaxNodeAnalysisContext context)
{
var node = (GenericNameSyntax)context.Node;

if (node.Identifier.Text != "Dictionary")
if (context.SemanticModel.GetDeclaredSymbol(node) is not INamedTypeSymbol symbol)
return;

var dictionaryTypeSymbol = GetSymbolFromContext(context, node) as ITypeSymbol;

if (dictionaryTypeSymbol is null)
if (TryGetDictionaryKeySymbol(symbol, typeof(Dictionary<,>), context, out INamedTypeSymbol? keySymbol) is false
& TryGetDictionaryKeySymbol(symbol, typeof(IReadOnlyDictionary<,>), context, out keySymbol) is false
& TryGetDictionaryKeySymbol(symbol, typeof(IDictionary<,>), context, out keySymbol) is false)
{
return;
}

var keyTypeSymbol = dictionaryTypeSymbol is INamedTypeSymbol namedTypeSymbol
? namedTypeSymbol.TypeArguments.FirstOrDefault()
: null;

if (keyTypeSymbol is null || keyTypeSymbol.MetadataName == "TKey")
if (keySymbol is null || keySymbol.MetadataName is "TKey")
return;
if (keyTypeSymbol.TypeKind is TypeKind.Enum)

if (keySymbol.TypeKind is TypeKind.Enum)
return;

var interfaceNamedType = context.Compilation.GetTypeSymbol(typeof(IEquatable<>));
INamedTypeSymbol equatableSymbol = context.Compilation.GetTypeSymbol(typeof(IEquatable<>));

INamedTypeSymbol madeEquatableSymbol = equatableSymbol
.Construct(keySymbol.WithNullableAnnotation(NullableAnnotation.None));

var equatableInterfaces = keyTypeSymbol.FindAssignableTypesConstructedFrom(interfaceNamedType);
IEnumerable<INamedTypeSymbol> foundEquatableSymbols = keySymbol
.FindAssignableTypesConstructedFrom(equatableSymbol);

var isThereRightEquatableInterface =
equatableInterfaces
.Select(s => s.TypeArguments.First())
.Any(s => keyTypeSymbol.IsAssignableTo(s));
bool hasCorrectEquatableImplementation = foundEquatableSymbols
.Select(x => x.TypeArguments.First())
.Any(x => madeEquatableSymbol.Equals(x, SymbolEqualityComparer.Default));

if (isThereRightEquatableInterface)
if (hasCorrectEquatableImplementation is false)
return;

var diag = Diagnostic.Create(Descriptor, node.GetLocation());
context.ReportDiagnostic(diag);
}

private static ISymbol? GetSymbolFromContext(SyntaxNodeAnalysisContext context, SyntaxNode node)
private static bool TryGetDictionaryKeySymbol(
INamedTypeSymbol nameSymbol,
Type dictionaryType,
SyntaxNodeAnalysisContext context,
out INamedTypeSymbol? keySymbol)
{
var model = context.SemanticModel;
var symbolInfo = model.GetSymbolInfo(node);
var symbol = symbolInfo.Symbol;
INamedTypeSymbol dictionarySymbol = context.Compilation.GetTypeSymbol(dictionaryType);
INamedTypeSymbol? implementationSymbol = nameSymbol.FindAssignableTypeConstructedFrom(dictionarySymbol);

return symbol;
keySymbol = implementationSymbol?.TypeArguments.FirstOrDefault() as INamedTypeSymbol;
return keySymbol is not null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ private void GenerateForType(GeneratorExecutionContext context, INamedTypeSymbol
CompilationUnitSyntax compilationUnit = _chain.Process(fileCommand);
string fileName = GetFileName(symbol.Name);

context.AddSource(fileName, compilationUnit.NormalizeWhitespace().ToFullString().Replace("\r\n", "\n"));
context.AddSource(fileName, compilationUnit.NormalizeWhitespace(eol: "\n").ToFullString());
}
catch (Exception e)
{
Expand Down
12 changes: 12 additions & 0 deletions tests/SourceKit.Analyzers.Collections.Tests/CollectionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ await AnalyzerTest
.RunAsync();
}

[Fact]
public async Task DictionaryKeyType_ShouldNotReportDiagnostic_WhenKeyIsString()
{
SourceFile sourceFile = await SourceFile.LoadAsync(
"SourceKit.Analyzers.Collections.Samples/Dictionary/DictionaryStringString.cs");

await AnalyzerTest
.WithSource(sourceFile)
.Build()
.RunAsync();
}

[Fact]
public async Task DictionaryCustomKeyType_ShouldReportDiagnostic_WhenTypeImplementsOtherEquatable()
{
Expand Down
4 changes: 2 additions & 2 deletions tests/SourceKit.Tests.Common/SourceFile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ namespace SourceKit.Tests.Common;
public readonly record struct SourceFile(string Name, string Content, Encoding? Encoding = null)
{
public string FilePath { get; init; } = Name;

public static async Task<SourceFile> LoadAsync(string path)
{
string name = Path.GetFileName(path);
string content = await File.ReadAllTextAsync(path);

return new SourceFile(name, content, null) { FilePath = path };
return new SourceFile(name, content) { FilePath = path };
}

public static implicit operator (string, SourceText)(SourceFile sourceFile)
Expand Down

0 comments on commit 621083c

Please sign in to comment.