diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java index e523006bb..70c122382 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThat.java @@ -72,7 +72,8 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu Expression expected = args.get(0); Expression actual = args.get(1); - maybeAddImport("org.assertj.core.api.Assertions", "assertThat"); + //always add the import (even if not referenced) + maybeAddImport("org.assertj.core.api.Assertions", "assertThat", false); maybeRemoveImport("org.junit.jupiter.api.Assertions"); if (args.size() == 2) { diff --git a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java index e0afdadad..4d2d52470 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/JUnitAssertEqualsToAssertThatTest.java @@ -384,23 +384,23 @@ public record OwnClass(String a) { } } """, - """ + """ package org.example; - import static org.assertj.core.api.Assertions.assertThat; - - import org.junit.jupiter.api.Test; - - class ATest { + import org.junit.jupiter.api.Test; + + import static org.assertj.core.api.Assertions.assertThat; + + class ATest { - @Test void testEquals() { - assertThat(new OwnClass()).isEqualTo(new OwnClass()); - } + @Test void testEquals() { + assertThat(new OwnClass()).isEqualTo(new OwnClass()); + } - public record OwnClass(String a) { + public record OwnClass(String a) { - public OwnClass() {this("1");} - } + public OwnClass() {this("1");} + } } """ )