Skip to content

Commit

Permalink
Fix performance regression in AssertJ recipes
Browse files Browse the repository at this point in the history
Since commit 368384a , AssertJ recipes are really slower since calling contextSensitive() method  disable a cache on JavaTemplate.

Fixes:
- Make sure there is a static import org.assertj.core.api.Assertions.* , even if not referenced (see #491 and #479)
- reintroduced assertionsParser cache in JUnitAssertEqualsToAssertThat
  • Loading branch information
philippe-granet committed Apr 14, 2024
1 parent 52f6e22 commit a992697
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
Expression expected = args.get(0);
Expression actual = args.get(1);

maybeAddImport("org.assertj.core.api.Assertions", "assertThat");
// Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);

if (args.size() == 2) {
Expand All @@ -93,7 +94,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, message, expected);
} else if (args.size() == 3) {
maybeAddImport("org.assertj.core.api.Assertions", "within");
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
// assert is using floating points with a delta and no message.
return JavaTemplate.builder("assertThat(#{anyArray()}).containsExactly(#{anyArray()}, within(#{any()}));")
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
Expand All @@ -104,7 +105,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);
maybeAddImport("org.assertj.core.api.Assertions", "within");
maybeAddImport("org.assertj.core.api.Assertions", "within", false);

JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{anyArray()}).as(#{any(String)}).containsExactly(#{anyArray()}, within(#{any()}));") :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
}

public static class AssertEqualsToAssertThatVisitor extends JavaIsoVisitor<ExecutionContext> {
private JavaParser.Builder<?, ?> assertionsParser;

private JavaParser.Builder<?, ?> assertionsParser(ExecutionContext ctx) {
if (assertionsParser == null) {
assertionsParser = JavaParser.fromJavaVersion()
.classpathFromResources(ctx, "assertj-core-3.24");
}
return assertionsParser;
}

private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertEquals(..)");

@Override
Expand All @@ -63,13 +73,14 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

//always add the import (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);

// Remove import for "org.junit.jupiter.api.Assertions" if no longer used.
maybeRemoveImport("org.junit.jupiter.api.Assertions");

if (args.size() == 2) {
return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, expected);
} else if (args.size() == 3 && !isFloatingPointType(args.get(2))) {
Expand All @@ -78,10 +89,9 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isEqualTo(#{any()});") :
JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isEqualTo(#{any()});");
return template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.imports("java.util.function.Supplier")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.javaParser(assertionsParser(ctx))
.build()
.apply(
getCursor(),
Expand All @@ -91,11 +101,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
expected
);
} else if (args.size() == 3) {
maybeAddImport("org.assertj.core.api.Assertions", "within");
//always add the import (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.javaParser(assertionsParser(ctx))
.build()
.apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2));

Expand All @@ -104,15 +114,15 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
// The assertEquals is using a floating point with a delta argument and a message.
Expression message = args.get(3);

maybeAddImport("org.assertj.core.api.Assertions", "within");
//always add the import (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
JavaTemplate.Builder template = TypeUtils.isString(message.getType()) ?
JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isCloseTo(#{any()}, within(#{any()}));") :
JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isCloseTo(#{any()}, within(#{any()}));");
return template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.imports("java.util.function.Supplier")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24"))
.javaParser(assertionsParser(ctx))
.build()
.apply(
getCursor(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
);
}

maybeAddImport("org.assertj.core.api.Assertions", "assertThat");
//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false););

// Remove import for "org.junit.jupiter.api.Assertions" if no longer used.
maybeRemoveImport("org.junit.jupiter.api.Assertions");

return method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

if (args.size() == 2) {
method = JavaTemplate.builder("assertThat(#{any()}).isNotEqualTo(#{any()});")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -93,7 +92,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu


method = template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -106,7 +104,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
);
} else if (args.size() == 3) {
method = JavaTemplate.builder("assertThat(#{any()}).isNotCloseTo(#{any()}, within(#{any()}));")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -117,7 +114,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
expected,
args.get(2)
);
maybeAddImport("org.assertj.core.api.Assertions", "within");
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
} else {
Expression message = args.get(3);

Expand All @@ -126,7 +123,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotCloseTo(#{any()}, within(#{any()}));");

method = template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -139,12 +135,13 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
args.get(2)
);

maybeAddImport("org.assertj.core.api.Assertions", "within");
maybeAddImport("org.assertj.core.api.Assertions", "within", false);
}

//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat"
maybeAddImport("org.assertj.core.api.Assertions", "assertThat");
//And if there are no longer references to the JUnit assertions class, we can remove the import.
//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);

// Remove import for "org.junit.jupiter.api.Assertions" if no longer used.
maybeRemoveImport("org.junit.jupiter.api.Assertions");

return method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

if (args.size() == 1) {
method = JavaTemplate.builder("assertThat(#{any()}).isNotNull();")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -89,7 +88,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotNull();");

method = template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -101,8 +99,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
);
}

//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);

//And if there are no longer references to the JUnit assertions class, we can remove the import.
maybeRemoveImport("org.junit.jupiter.api.Assertions");
maybeAddImport("org.assertj.core.api.Assertions", "assertThat");

return method;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

if (args.size() == 1) {
method = JavaTemplate.builder("assertThat(#{any()}).isNull();")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -88,7 +87,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNull();");

method = template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -100,12 +98,12 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
);
}

// Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);

// Remove import for "org.junit.jupiter.api.Assertions" if no longer used.
maybeRemoveImport("org.junit.jupiter.api.Assertions");

// Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat".
maybeAddImport("org.assertj.core.api.Assertions", "assertThat");

return method;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu

if (args.size() == 2) {
method = JavaTemplate.builder("assertThat(#{any()}).isSameAs(#{any()});")
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -90,7 +89,6 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isSameAs(#{any()});");

method = template
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThat")
.javaParser(assertionsParser(ctx))
.build()
Expand All @@ -103,8 +101,11 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
);
}

maybeRemoveImport("org.junit.jupiter.api.Assertions");
maybeAddImport("org.assertj.core.api.Assertions", "assertThat");
// Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);

// Remove import for "org.junit.jupiter.api.Assertions" if no longer used.
maybeRemoveImport(JUNIT_QUALIFIED_ASSERTIONS_CLASS_NAME);

return method;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ && getCursor().getParentTreeCursor().getValue() instanceof J.Block) {
mi.getCoordinates().replace(),
mi.getArguments().get(0), executable
);
maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType");
maybeAddImport("org.assertj.core.api.AssertionsForClassTypes", "assertThatExceptionOfType", false);
maybeRemoveImport("org.junit.jupiter.api.Assertions.assertThrows");
maybeRemoveImport("org.junit.jupiter.api.Assertions");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
);
}

maybeAddImport("org.assertj.core.api.Assertions", "assertThat");
//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false);

// Remove import for "org.junit.jupiter.api.Assertions" if no longer used.
maybeRemoveImport("org.junit.jupiter.api.Assertions");

return method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
method.getCoordinates().replace(),
arguments.toArray()
);
maybeAddImport("org.assertj.core.api.Assertions", "fail");
//Make sure there is a static import for "org.assertj.core.api.Assertions.assertThat" (even if not referenced)
maybeAddImport("org.assertj.core.api.Assertions", "fail", false);
return super.visitMethodInvocation(method, ctx);
}
}
Expand Down

0 comments on commit a992697

Please sign in to comment.