diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java index 70c122382..51a6d85c5 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java @@ -49,17 +49,6 @@ public TreeVisitor getVisitor() { } public static class AssertEqualsToAssertThatVisitor extends JavaIsoVisitor { - private JavaParser.Builder assertionsParser; - - private JavaParser.Builder assertionsParser(ExecutionContext ctx) { - if (assertionsParser == null) { - assertionsParser = JavaParser.fromJavaVersion() - .classpathFromResources(ctx, "assertj-core-3.24"); - } - return assertionsParser; - } - - private static final MethodMatcher JUNIT_ASSERT_EQUALS = new MethodMatcher("org.junit.jupiter.api.Assertions" + " assertEquals(..)"); @Override @@ -79,7 +68,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 2) { return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});") .staticImports("org.assertj.core.api.Assertions.assertThat") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() .apply(getCursor(), method.getCoordinates().replace(), actual, expected); } else if (args.size() == 3 && !isFloatingPointType(args.get(2))) { @@ -90,7 +79,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu return template .staticImports("org.assertj.core.api.Assertions.assertThat") .imports("java.util.function.Supplier") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() .apply( getCursor(), @@ -103,7 +92,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu maybeAddImport("org.assertj.core.api.Assertions", "within"); return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));") .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() .apply(getCursor(), method.getCoordinates().replace(), actual, expected, args.get(2)); @@ -119,7 +108,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu return template .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .imports("java.util.function.Supplier") - .javaParser(assertionsParser(ctx)) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() .apply( getCursor(), diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java index 4d2d52470..b66a3ac75 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java @@ -360,45 +360,40 @@ private File notification() { } @Test - @Issue("479") + @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/479") void shouldImportWhenCustomClassIsUsed() { //language=java rewriteRun( + // The JavaParer in JavaTemplate only has AssertJ on the classpath, and for now is not .contextSenstive() spec -> spec.typeValidationOptions(TypeValidation.none()), java( """ - package org.example; - import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; class ATest { - - @Test void testEquals() { + @Test + void testEquals() { Assertions.assertEquals(new OwnClass(), new OwnClass()); } public record OwnClass(String a) { - public OwnClass() {this("1");} } } """, """ - package org.example; - import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; class ATest { - - @Test void testEquals() { + @Test + void testEquals() { assertThat(new OwnClass()).isEqualTo(new OwnClass()); } public record OwnClass(String a) { - public OwnClass() {this("1");} } }