diff --git a/src/RoslynContractFix/ContractFix/ContractFix.Vsix/source.extension.vsixmanifest b/src/RoslynContractFix/ContractFix/ContractFix.Vsix/source.extension.vsixmanifest index 8ce1274..084ce58 100644 --- a/src/RoslynContractFix/ContractFix/ContractFix.Vsix/source.extension.vsixmanifest +++ b/src/RoslynContractFix/ContractFix/ContractFix.Vsix/source.extension.vsixmanifest @@ -1,7 +1,7 @@ - + ContractFix CodeContracts remover diff --git a/src/RoslynContractFix/ContractFix/ContractFix/CodeContractFromBase/CodeContractFromBaseAnalyzer.cs b/src/RoslynContractFix/ContractFix/ContractFix/CodeContractFromBase/CodeContractFromBaseAnalyzer.cs index c15beb9..24e662c 100644 --- a/src/RoslynContractFix/ContractFix/ContractFix/CodeContractFromBase/CodeContractFromBaseAnalyzer.cs +++ b/src/RoslynContractFix/ContractFix/ContractFix/CodeContractFromBase/CodeContractFromBaseAnalyzer.cs @@ -48,7 +48,7 @@ public override void Initialize(AnalysisContext context) private static IEnumerable GetInterfaceImplementation(IMethodSymbol method) { return method.ContainingType.AllInterfaces.SelectMany(@interface => @interface.GetMembers().OfType()). - Where(interfaceMethod => method.ContainingType.FindImplementationForInterfaceMember(interfaceMethod).Equals(method)); + Where(interfaceMethod => method.ContainingType.FindImplementationForInterfaceMember(interfaceMethod)?.Equals(method) ?? false); } private static IEnumerable GetOverridenMethods(IMethodSymbol method) { diff --git a/src/RoslynContractFix/ContractFix/ContractFix/RequiresGenericToIfThrow/RequiresGenericToIfThrowCodeFixProvider.cs b/src/RoslynContractFix/ContractFix/ContractFix/RequiresGenericToIfThrow/RequiresGenericToIfThrowCodeFixProvider.cs index a2e9526..016ebe6 100644 --- a/src/RoslynContractFix/ContractFix/ContractFix/RequiresGenericToIfThrow/RequiresGenericToIfThrowCodeFixProvider.cs +++ b/src/RoslynContractFix/ContractFix/ContractFix/RequiresGenericToIfThrow/RequiresGenericToIfThrowCodeFixProvider.cs @@ -51,6 +51,17 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context) context.Diagnostics); } + private static bool IsComparisonExpr(ExpressionSyntax expr) + { + var exprKind = expr.Kind(); + return exprKind == SyntaxKind.NotEqualsExpression || + exprKind == SyntaxKind.EqualsExpression || + exprKind == SyntaxKind.GreaterThanExpression || + exprKind == SyntaxKind.GreaterThanOrEqualExpression || + exprKind == SyntaxKind.LessThanExpression || + exprKind == SyntaxKind.LessThanOrEqualExpression; + } + private static ExpressionSyntax SmartNotExpression(ExpressionSyntax expr, SyntaxGenerator generator) { if (expr is BinaryExpressionSyntax binary) @@ -69,8 +80,17 @@ private static ExpressionSyntax SmartNotExpression(ExpressionSyntax expr, Syntax return (ExpressionSyntax)generator.GreaterThanOrEqualExpression(binary.Left, binary.Right); case SyntaxKind.LessThanOrEqualExpression: return (ExpressionSyntax)generator.GreaterThanExpression(binary.Left, binary.Right); + case SyntaxKind.LogicalOrExpression when IsComparisonExpr(binary.Left) && IsComparisonExpr(binary.Right): + return (ExpressionSyntax)generator.LogicalAndExpression(SmartNotExpression(binary.Left, generator), SmartNotExpression(binary.Right, generator)); + case SyntaxKind.LogicalAndExpression when IsComparisonExpr(binary.Left) && IsComparisonExpr(binary.Right): + return (ExpressionSyntax)generator.LogicalOrExpression(SmartNotExpression(binary.Left, generator), SmartNotExpression(binary.Right, generator)); } } + else if (expr is PrefixUnaryExpressionSyntax prefixUnary) + { + if (prefixUnary.IsKind(SyntaxKind.LogicalNotExpression)) + return prefixUnary.Operand; + } return (ExpressionSyntax)generator.LogicalNotExpression(expr); } diff --git a/src/TestSolution/TestProject/Program.cs b/src/TestSolution/TestProject/Program.cs index 0e4963a..13a3dbc 100644 --- a/src/TestSolution/TestProject/Program.cs +++ b/src/TestSolution/TestProject/Program.cs @@ -24,6 +24,7 @@ static void Test(string val) static void ContractTest(string val, int data) { Contract.Requires(val != null, "aaa" + "ab"); + Contract.Requires(!string.IsNullOrEmpty(val)); Contract.Requires(data >= 0); Contract.Requires(data != 0 && val != "fff"); Contract.Requires(data < 100);