diff --git a/src/main/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocations.java b/src/main/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocations.java index 82fbf11806..af4a19387a 100644 --- a/src/main/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocations.java +++ b/src/main/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocations.java @@ -15,9 +15,6 @@ */ package org.openrewrite.java.migrate.lang.var; -import java.util.ArrayList; -import java.util.List; - import org.openrewrite.*; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.JavaParser; @@ -26,6 +23,9 @@ import org.openrewrite.java.tree.*; import org.openrewrite.marker.Markers; +import java.util.ArrayList; +import java.util.List; + import static java.util.Collections.emptyList; public class UseVarForGenericMethodInvocations extends Recipe { @@ -51,6 +51,7 @@ public TreeVisitor getVisitor() { static final class UseVarForGenericsVisitor extends JavaIsoVisitor { private final JavaTemplate template = JavaTemplate.builder("var #{} = #{any()}") + .contextSensitive() .javaParser(JavaParser.fromJavaVersion()).build(); @Override @@ -77,9 +78,29 @@ public J.VariableDeclarations visitVariableDeclarations(J.VariableDeclarations v if (hasNoTypeParams && argumentsEmpty) return vd; // mark imports for removal if unused - if (vd.getType() instanceof JavaType.FullyQualified) maybeRemoveImport((JavaType.FullyQualified) vd.getType()); + JavaType typeRemoved = vd.getType(); + if (typeRemoved instanceof JavaType.FullyQualified) { + if (typeRemoved instanceof JavaType.Parameterized) { // parameterized have to be decomposed + typeRemoved = ((JavaType.Parameterized) typeRemoved).getType(); + } + maybeRemoveImport((JavaType.FullyQualified) typeRemoved); + } - return transformToVar(vd, new ArrayList<>(), new ArrayList<>()); + + //determine types + List leftTypes = new ArrayList<>(); + TypeTree leftTypeExpression = vd.getTypeExpression(); + if (leftTypeExpression != null) { + leftTypes.add(leftTypeExpression.getType()); + } + + List rightTypes = new ArrayList<>(); + JavaType initializerType = initializer.getType(); + if (initializerType != null) { + rightTypes.add(initializerType); + } + + return transformToVar(vd, leftTypes, rightTypes); } private static boolean allArgumentsEmpty(J.MethodInvocation invocation) { @@ -100,10 +121,13 @@ private J.VariableDeclarations transformToVar(J.VariableDeclarations vd, List typeArgument = new ArrayList<>(); for (JavaType t : leftTypes) { - typeArgument.add(new J.Identifier(Tree.randomId(), Space.EMPTY, Markers.EMPTY, emptyList(), ((JavaType.Class) t).getClassName(), t, null)); + typeArgument.add(new J.Identifier(Tree.randomId(), Space.EMPTY, Markers.EMPTY, emptyList(), ((JavaType.FullyQualified) t).getClassName(), t, null)); + } + + if (initializer instanceof J.NewClass) { // for constructor invocations we need to handle generics + J.ParameterizedType typedInitializerClazz = ((J.ParameterizedType) ((J.NewClass) initializer).getClazz()).withTypeParameters(typeArgument); + initializer = ((J.NewClass) initializer).withClazz(typedInitializerClazz); } - J.ParameterizedType typedInitializerClazz = ((J.ParameterizedType) ((J.NewClass) initializer).getClazz()).withTypeParameters(typeArgument); - initializer = ((J.NewClass) initializer).withClazz(typedInitializerClazz); } J.VariableDeclarations result = template.apply(getCursor(), vd.getCoordinates().replace(), simpleName, initializer) diff --git a/src/test/java/org/openrewrite/java/migrate/lang/UseVarKeywordTest.java b/src/test/java/org/openrewrite/java/migrate/lang/UseVarKeywordTest.java index 52b79b0631..f003868e7f 100644 --- a/src/test/java/org/openrewrite/java/migrate/lang/UseVarKeywordTest.java +++ b/src/test/java/org/openrewrite/java/migrate/lang/UseVarKeywordTest.java @@ -15,9 +15,6 @@ */ package org.openrewrite.java.migrate.lang; -import static org.openrewrite.java.Assertions.java; -import static org.openrewrite.java.Assertions.version; - import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -26,6 +23,9 @@ import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; +import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.java.Assertions.version; + class UseVarKeywordTest implements RewriteTest { @Override @@ -37,6 +37,233 @@ public void defaults(RecipeSpec spec) { .activateRecipes("org.openrewrite.java.migrate.lang.UseVar")); } + @Nested + class BugFixing { + + @Test + void anonymousClass() { + // spring-projects/spring-data-commons @ main: src/test/java/org/springframework/data/domain/ManagedTypesUnitTests.java + // solving: Expected a template that would generate exactly one statement to replace one statement, but generated 2. Template: + //var typesSupplier = __P__./*__p1__*/p() + //language=java + rewriteRun( + version( + java(""" + package com.example.app; + + import java.util.Collections; + import java.util.function.Supplier; + + class ManagedTypesUnitTests { + void supplierBasedManagedTypesAreEvaluatedLazily() { + Supplier>> typesSupplier = spy(new Supplier>>() { + @Override + public Iterable> get() { + return Collections.singleton(Object.class); + } + }); + } + + // mock for mockito method + private Supplier>> spy(Supplier>> supplier) { + return null; + } + } + """, """ + package com.example.app; + + import java.util.Collections; + import java.util.function.Supplier; + + class ManagedTypesUnitTests { + void supplierBasedManagedTypesAreEvaluatedLazily() { + var typesSupplier = spy(new Supplier>>() { + @Override + public Iterable> get() { + return Collections.singleton(Object.class); + } + }); + } + + // mock for mockito method + private Supplier>> spy(Supplier>> supplier) { + return null; + } + } + """), + 10 + ) + ); + } + + @Test + void multiGenerics() { + // spring-cloud/spring-cloud-contract @ main: spring-cloud-contract-verifier/src/test/resources/contractsToCompile/contract_multipart.java + // solving java.lang.IllegalArgumentException: Unable to parse expression from JavaType Unknown + //language=java + rewriteRun( + version( + java(""" + import java.util.Collection; + import java.util.HashMap; + import java.util.Map; + import java.util.function.Supplier; + + class contract_multipart implements Supplier> { \s + private static Map namedProps(HttpSender.Request r) { + Map map = new HashMap<>(); + return map; + } + + @Override + public Collection get() { return null; } + } + // replacements + class Contract{} + class DslProperty{} + class HttpSender { + static class Request {} + } + """, """ + import java.util.Collection; + import java.util.HashMap; + import java.util.Map; + import java.util.function.Supplier; + + class contract_multipart implements Supplier> { \s + private static Map namedProps(HttpSender.Request r) { + var map = new HashMap(); + return map; + } + + @Override + public Collection get() { return null; } + } + // replacements + class Contract{} + class DslProperty{} + class HttpSender { + static class Request {} + } + """), + 10 + ) + ); + } + + @Test + void duplicateTemplate() { + // spring-projects/spring-hateoas @ main src/test/java/org/springframework/hateoas/mediatype/html/HtmlInputTypeUnitTests.java + // solving Expected a template that would generate exactly one statement to replace one statement, but generated 2. Template: + //var numbers = __P__.>/*__p1__*/p() + //language=java + rewriteRun( + version( + java(""" + import java.math.BigDecimal; + import java.util.Arrays; + import java.util.Collection; + import java.util.stream.Stream; + + class HtmlInputTypeUnitTests { + Stream derivesInputTypesFromType() { + Stream<$> numbers = HtmlInputType.NUMERIC_TYPES.stream() // + .map(it -> $.of(it, HtmlInputType.NUMBER)); + return null; + } + + static class HtmlInputType { + static final Collection> NUMERIC_TYPES = Arrays.asList(int.class, long.class, float.class, + double.class, short.class, Integer.class, Long.class, Float.class, Double.class, Short.class, BigDecimal.class); + + public static final HtmlInputType NUMBER = new HtmlInputType(); + + public static HtmlInputType from(Class type) { return null; } + } + + static class $ { + + Class type; + HtmlInputType expected; + + static $ of(Class it, HtmlInputType number){ return null; }\s + + public void verify() { + assertThat(HtmlInputType.from(type)).isEqualTo(expected); + } + + @Override + public String toString() { + return String.format("Derives %s from %s.", expected, type); + } + //mocking + private > AbstractBigDecimalAssert assertThat(HtmlInputType from) { + return null; + } + } + } + // replacement + class DynamicTest {} + class AbstractBigDecimalAssert { + public void isEqualTo(Object expected) {} + } + """, """ + import java.math.BigDecimal; + import java.util.Arrays; + import java.util.Collection; + import java.util.stream.Stream; + + import static org.assertj.core.api.Assertions.assertThat; + + class HtmlInputTypeUnitTests { + Stream derivesInputTypesFromType() { + var numbers = HtmlInputType.NUMERIC_TYPES.stream() // + .map(it -> $.of(it, HtmlInputType.NUMBER)); + return null; + } + + static class HtmlInputType { + static final Collection> NUMERIC_TYPES = Arrays.asList(int.class, long.class, float.class, + double.class, short.class, Integer.class, Long.class, Float.class, Double.class, Short.class, BigDecimal.class); + + public static final HtmlInputType NUMBER = new HtmlInputType(); + + public static HtmlInputType from(Class type) { return null; } + } + + static class $ { + + Class type; + HtmlInputType expected; + + static $ of(Class it, HtmlInputType number){ return null; }\s + + public void verify() { + assertThat(HtmlInputType.from(type)).isEqualTo(expected); + } + + @Override + public String toString() { + return String.format("Derives %s from %s.", expected, type); + } + //mocking + private > AbstractBigDecimalAssert assertThat(HtmlInputType from) { + return null; + } + } + } + // replacement + class DynamicTest {} + class AbstractBigDecimalAssert { + public void isEqualTo(Object expected) {} + } + """), + 10 + ) + ); + } + } + @Nested class GeneralNotApplicable { @@ -142,14 +369,14 @@ void withTernary() { rewriteRun( version( java(""" - package com.example.app; - - class A { - void m() { - String o = true ? "isTrue" : "Test"; - } - } - """), + package com.example.app; + + class A { + void m() { + String o = true ? "isTrue" : "Test"; + } + } + """), 10 ) ); diff --git a/src/test/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocationsTest.java b/src/test/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocationsTest.java index 8e0e9a3332..e2f140cba0 100644 --- a/src/test/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocationsTest.java +++ b/src/test/java/org/openrewrite/java/migrate/lang/var/UseVarForGenericMethodInvocationsTest.java @@ -15,14 +15,14 @@ */ package org.openrewrite.java.migrate.lang.var; -import static org.openrewrite.java.Assertions.*; - import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; +import static org.openrewrite.java.Assertions.*; + public class UseVarForGenericMethodInvocationsTest implements RewriteTest { @Override public void defaults(RecipeSpec spec) { @@ -256,5 +256,111 @@ void m() { ) ); } + + @Test + void streamUsage() { + //language=java + rewriteRun( + version( + java(""" + import java.math.BigDecimal; + import java.util.Arrays; + import java.util.Collection; + import java.util.stream.Stream; + + class HtmlInputTypeUnitTests { + Stream derivesInputTypesFromType() { + Stream<$> numbers = HtmlInputType.NUMERIC_TYPES.stream().map(it -> $.of(it, HtmlInputType.NUMBER)); + return null; + } + + static class HtmlInputType { + static final Collection> NUMERIC_TYPES = Arrays.asList(int.class, long.class, float.class, + double.class, short.class, Integer.class, Long.class, Float.class, Double.class, Short.class, BigDecimal.class); + + public static final HtmlInputType NUMBER = new HtmlInputType(); + + public static HtmlInputType from(Class type) { return null; } + } + + static class $ { + + Class type; + HtmlInputType expected; + + static $ of(Class it, HtmlInputType number){ return null; } + + public void verify() { + assertThat(HtmlInputType.from(type)).isEqualTo(expected); + } + + @Override + public String toString() { + return String.format("Derives %s from %s.", expected, type); + } + //mocking + private > AbstractBigDecimalAssert assertThat(HtmlInputType from) { + return null; + } + } + } + // replacement + class DynamicTest {} + class AbstractBigDecimalAssert { + public void isEqualTo(Object expected) {} + } + """, """ + import java.math.BigDecimal; + import java.util.Arrays; + import java.util.Collection; + import java.util.stream.Stream; + + class HtmlInputTypeUnitTests { + Stream derivesInputTypesFromType() { + var numbers = HtmlInputType.NUMERIC_TYPES.stream() + .map(it -> $.of(it, HtmlInputType.NUMBER)); + return null; + } + + static class HtmlInputType { + static final Collection> NUMERIC_TYPES = Arrays.asList(int.class, long.class, float.class, + double.class, short.class, Integer.class, Long.class, Float.class, Double.class, Short.class, BigDecimal.class); + + public static final HtmlInputType NUMBER = new HtmlInputType(); + + public static HtmlInputType from(Class type) { return null; } + } + + static class $ { + + Class type; + HtmlInputType expected; + + static $ of(Class it, HtmlInputType number){ return null; } + + public void verify() { + assertThat(HtmlInputType.from(type)).isEqualTo(expected); + } + + @Override + public String toString() { + return String.format("Derives %s from %s.", expected, type); + } + //mocking + private > AbstractBigDecimalAssert assertThat(HtmlInputType from) { + return null; + } + } + } + // replacement + class DynamicTest {} + class AbstractBigDecimalAssert { + public void isEqualTo(Object expected) {} + } + """), + 10 + ) + ); + } } }