diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java index 51a6d85c5..de96ac7fd 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java @@ -67,6 +67,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 2) { return JavaTemplate.builder("assertThat(#{any()}).isEqualTo(#{any()});") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() @@ -77,6 +78,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isEqualTo(#{any()});") : JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isEqualTo(#{any()});"); return template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .imports("java.util.function.Supplier") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) @@ -91,6 +93,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu } else if (args.size() == 3) { maybeAddImport("org.assertj.core.api.Assertions", "within"); return JavaTemplate.builder("assertThat(#{any()}).isCloseTo(#{any()}, within(#{any()}));") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) .build() @@ -106,6 +109,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(String)}).isCloseTo(#{any()}, within(#{any()}));") : JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isCloseTo(#{any()}, within(#{any()}));"); return template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .imports("java.util.function.Supplier") .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3.24")) diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java index 3401f5048..20f83f869 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotEqualsToAssertThat.java @@ -74,6 +74,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 2) { method = JavaTemplate.builder("assertThat(#{any()}).isNotEqualTo(#{any()});") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -92,6 +93,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu method = template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -104,6 +106,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu ); } else if (args.size() == 3) { method = JavaTemplate.builder("assertThat(#{any()}).isNotCloseTo(#{any()}, within(#{any()}));") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .javaParser(assertionsParser(ctx)) .build() @@ -123,6 +126,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotCloseTo(#{any()}, within(#{any()}));"); method = template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat", "org.assertj.core.api.Assertions.within") .javaParser(assertionsParser(ctx)) .build() diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java index 04d8c80b3..2fde4a8b4 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThat.java @@ -71,6 +71,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 1) { method = JavaTemplate.builder("assertThat(#{any()}).isNotNull();") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -88,6 +89,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNotNull();"); method = template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java index 98b76e009..e94af26e1 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertNullToAssertThat.java @@ -71,6 +71,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 1) { method = JavaTemplate.builder("assertThat(#{any()}).isNull();") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -87,6 +88,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isNull();"); method = template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java index 4c710fc04..67b4e4d83 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertSameToAssertThat.java @@ -72,6 +72,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu if (args.size() == 2) { method = JavaTemplate.builder("assertThat(#{any()}).isSameAs(#{any()});") + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() @@ -89,6 +90,7 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu JavaTemplate.builder("assertThat(#{any()}).as(#{any(java.util.function.Supplier)}).isSameAs(#{any()});"); method = template + .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThat") .javaParser(assertionsParser(ctx)) .build() diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java index 2428a80e5..8a4bcb5ec 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertNotNullToAssertThatTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.Issue; import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; @@ -241,4 +242,44 @@ private String notification() { ) ); } + + @Test + @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/491") + void importAddedForCustomArguments() { + rewriteRun( + //language=java + java( + """ + import org.junit.jupiter.api.Test; + + import static org.junit.jupiter.api.Assertions.assertNotNull; + + class TTest { + + class A {} + + @Test + public void testClass() { + assertNotNull(new A()); + } + } + """, + """ + import org.junit.jupiter.api.Test; + + import static org.assertj.core.api.Assertions.assertThat; + + class TTest { + + class A {} + + @Test + public void testClass() { + assertThat(new A()).isNotNull(); + } + } + """ + ) + ); + } }