diff --git a/src/main/java/org/openrewrite/FindCallGraph.java b/src/main/java/org/openrewrite/FindCallGraph.java index 017ccf3..f8a49ce 100644 --- a/src/main/java/org/openrewrite/FindCallGraph.java +++ b/src/main/java/org/openrewrite/FindCallGraph.java @@ -21,6 +21,7 @@ import org.openrewrite.java.tree.J; import org.openrewrite.java.tree.JavaType; import org.openrewrite.marker.Markup; +import org.openrewrite.marker.SourceSet; import org.openrewrite.table.CallGraph; import java.util.*; @@ -94,23 +95,28 @@ private T recordCall(T j, ExecutionContext ctx) { return j; } Cursor scope = getCursor().dropParentUntil(it -> it instanceof J.MethodDeclaration || it instanceof J.ClassDeclaration || it instanceof SourceFile); + String sourceSet = Optional.ofNullable(scope.firstEnclosing(SourceFile.class)) + .map(Tree::getMarkers) + .flatMap(m -> m.findFirst(SourceSet.class)) + .map(SourceSet::getName) + .orElse("unknown"); if (scope.getValue() instanceof J.ClassDeclaration) { boolean isInStaticInitializer = inStaticInitializer(); if ((isInStaticInitializer && scope.computeMessageIfAbsent("METHODS_CALLED_IN_STATIC_INITIALIZATION", k -> new HashSet<>()).add(method)) || (!isInStaticInitializer && scope.computeMessageIfAbsent("METHODS_CALLED_IN_INSTANCE_INITIALIZATION", k -> new HashSet<>()).add(method))) { - callGraph.insertRow(ctx, row(requireNonNull(((J.ClassDeclaration) scope.getValue()).getType()).getFullyQualifiedName(), method)); + callGraph.insertRow(ctx, row(sourceSet, requireNonNull(((J.ClassDeclaration) scope.getValue()).getType()).getFullyQualifiedName(), method)); } } else if (scope.getValue() instanceof J.MethodDeclaration) { Set methodsCalledInScope = scope.computeMessageIfAbsent("METHODS_CALLED_IN_SCOPE", k -> new HashSet<>()); if (methodsCalledInScope.add(method)) { - callGraph.insertRow(ctx, row(requireNonNull(((J.MethodDeclaration) scope.getValue()).getMethodType()), method)); + callGraph.insertRow(ctx, row(sourceSet,requireNonNull(((J.MethodDeclaration) scope.getValue()).getMethodType()), method)); } } else if (scope.getValue() instanceof SourceFile) { // In Java there has to be a class declaration, but that isn't the case in Groovy/Kotlin/etc. // So we'll just use the source file path instead Set methodsCalledInScope = scope.computeMessageIfAbsent("METHODS_CALLED_IN_SCOPE", k -> new HashSet<>()); if (methodsCalledInScope.add(method)) { - callGraph.insertRow(ctx, row(((SourceFile) scope.getValue()).getSourcePath().toString(), method)); + callGraph.insertRow(ctx, row(sourceSet, ((SourceFile) scope.getValue()).getSourcePath().toString(), method)); } } return j; @@ -139,8 +145,9 @@ private boolean inStaticInitializer() { return inStaticInitializer.get(); } - private CallGraph.Row row(String fqn, JavaType.Method to) { + private CallGraph.Row row(String sourceSet, String fqn, JavaType.Method to) { return new CallGraph.Row( + sourceSet, fqn, inStaticInitializer() ? "" : "", "", @@ -154,8 +161,9 @@ private CallGraph.Row row(String fqn, JavaType.Method to) { ); } - private CallGraph.Row row(JavaType.Method from, JavaType.Method to) { + private CallGraph.Row row(String sourceSet,JavaType.Method from, JavaType.Method to) { return new CallGraph.Row( + sourceSet, from.getDeclaringType().getFullyQualifiedName(), from.getName(), parameters(from), diff --git a/src/main/java/org/openrewrite/table/CallGraph.java b/src/main/java/org/openrewrite/table/CallGraph.java index f6219cf..67a0a7b 100644 --- a/src/main/java/org/openrewrite/table/CallGraph.java +++ b/src/main/java/org/openrewrite/table/CallGraph.java @@ -31,6 +31,11 @@ public CallGraph(Recipe recipe) { @Value public static class Row { + + @Column(displayName = "From source set", + description = "The source set from which the action is issued.") + String fromSourceSet; + @Column(displayName = "From class", description = "The fully qualified name of the class from which the action is issued.") String fromClass; diff --git a/src/test/java/org/openrewrite/FindCallGraphTest.java b/src/test/java/org/openrewrite/FindCallGraphTest.java index cac007c..958128f 100644 --- a/src/test/java/org/openrewrite/FindCallGraphTest.java +++ b/src/test/java/org/openrewrite/FindCallGraphTest.java @@ -22,7 +22,7 @@ import org.openrewrite.test.TypeValidation; import static org.assertj.core.api.Assertions.assertThat; -import static org.openrewrite.java.Assertions.java; +import static org.openrewrite.java.Assertions.*; import static org.openrewrite.kotlin.Assertions.kotlin; @SuppressWarnings({"UnusedAssignment", "DataFlowIssue", "InfiniteRecursion"}) @@ -39,6 +39,7 @@ void findUniqueCallsPerDeclaration() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "main", "Test", "test", "", @@ -51,6 +52,7 @@ void findUniqueCallsPerDeclaration() { "void" ), new CallGraph.Row( + "main", "Test", "test2", "", @@ -64,9 +66,9 @@ void findUniqueCallsPerDeclaration() { ) ) ), - //language=java - java( - """ + mavenProject("project", srcMainJava( + //language=java + java(""" class Test { void test() { System.out.println("Hello"); @@ -79,7 +81,8 @@ void test2() { } } """ - ) + ) + )) ); } @@ -91,6 +94,7 @@ void filterStdLib() { .dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "Test", "test", "", @@ -104,8 +108,7 @@ void filterStdLib() { ) )), //language=java - java( - """ + java(""" import java.util.List; import java.util.ArrayList; class Test { @@ -126,6 +129,7 @@ void staticInitializer() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "Scratch", "", "", @@ -138,6 +142,7 @@ void staticInitializer() { "int" ), new CallGraph.Row( + "unknown", "Scratch", "", "", @@ -152,8 +157,7 @@ void staticInitializer() { ) ), //language=java - java( - """ + java(""" class Scratch { static int i = bar(); static { @@ -172,6 +176,7 @@ void initializer() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "Scratch", "", "", @@ -184,6 +189,7 @@ void initializer() { "int" ), new CallGraph.Row( + "unknown", "Scratch", "", "", @@ -198,8 +204,7 @@ void initializer() { ) ), //language=java - java( - """ + java(""" class Scratch { int i = bar(); int j; @@ -220,6 +225,7 @@ void innerClass() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "A$B", "b", "", @@ -232,6 +238,7 @@ void innerClass() { "A$C" ), new CallGraph.Row( + "unknown", "A$B", "b", "", @@ -246,8 +253,7 @@ void innerClass() { ) ), //language=java - java( - """ + java(""" class A { class B { void b() { @@ -270,6 +276,7 @@ void anonymousClass() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).contains( new CallGraph.Row( + "unknown", "B", "call", "", @@ -282,6 +289,7 @@ void anonymousClass() { "void" ), new CallGraph.Row( + "unknown", "B", "", "", @@ -294,6 +302,7 @@ void anonymousClass() { "B$1" ), new CallGraph.Row( + "unknown", "B$1", "method", "", @@ -308,8 +317,7 @@ void anonymousClass() { ) ), //language=java - java( - """ + java(""" class A { public void method() {} } @@ -334,6 +342,7 @@ void companionObject() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "A$Companion", "main", "kotlin.Array", @@ -348,8 +357,7 @@ void companionObject() { ) ), //language=kotlin - kotlin( - """ + kotlin(""" class A { companion object { @JvmStatic @@ -368,8 +376,7 @@ void missingMethodMarked() { rewriteRun( spec -> spec.typeValidationOptions(TypeValidation.none()), //language=java - java( - """ + java(""" class A { String s = foo(); } @@ -388,6 +395,7 @@ void fieldDeclarationInitialization() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "A", "", "", @@ -400,6 +408,7 @@ void fieldDeclarationInitialization() { "java.lang.String" ), new CallGraph.Row( + "unknown", "A", "", "", @@ -414,8 +423,7 @@ void fieldDeclarationInitialization() { ) ), //language=java - java( - """ + java(""" class A { String instanceField = foo(); static String staticField = foo(); @@ -431,6 +439,7 @@ void initializerBlocks() { spec -> spec.dataTable(CallGraph.Row.class, row -> assertThat(row).containsExactly( new CallGraph.Row( + "unknown", "A", "", "", @@ -443,6 +452,7 @@ void initializerBlocks() { "java.lang.String" ), new CallGraph.Row( + "unknown", "A", "", "", @@ -457,8 +467,7 @@ void initializerBlocks() { ) ), //language=java - java( - """ + java(""" class A { String instanceField; {