diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt index 47fe7c678..816cad956 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/data/TestGenerationData.kt @@ -14,7 +14,7 @@ data class TestGenerationData( // Code required of imports and package for generated tests var importsCode: MutableSet = mutableSetOf(), var packageName: String = "", - var runWith: String = "", + var annotation: String = "", // Modifications to this code in the tool-window editor are forgotten when apply to test suite var otherInfo: String = "", // changing parameters with a large prompt @@ -30,7 +30,7 @@ data class TestGenerationData( fileUrl = "" importsCode = mutableSetOf() packageName = "" - runWith = "" + annotation = "" otherInfo = "" polyDepthReducing = 0 inputParamsDepthReducing = 0 diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt index 8d4456b3a..9613d4c7e 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/data/TestSuiteGeneratedByLLM.kt @@ -10,7 +10,7 @@ package org.jetbrains.research.testspark.core.test.data data class TestSuiteGeneratedByLLM( var imports: MutableSet = mutableSetOf(), var packageName: String = "", - var runWith: String = "", + var annotation: String = "", var otherInfo: String = "", var testCases: MutableList = mutableListOf(), ) { diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt index 279badc57..4cbdcb72c 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/java/JavaJUnitTestSuiteParser.kt @@ -22,7 +22,6 @@ class JavaJUnitTestSuiteParser( return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( rawText, - junitVersion, javaImportPattern, packageName, testNamePattern = "void", diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt index 18b164810..c3ad9ad82 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/kotlin/KotlinJUnitTestSuiteParser.kt @@ -22,7 +22,6 @@ class KotlinJUnitTestSuiteParser( return JUnitTestSuiteParserStrategy.parseJUnitTestSuite( rawText, - junitVersion, kotlinImportPattern, packageName, testNamePattern = "fun", diff --git a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt index f7b905dda..a1b96f967 100644 --- a/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt +++ b/core/src/main/kotlin/org/jetbrains/research/testspark/core/test/strategies/JUnitTestSuiteParserStrategy.kt @@ -21,7 +21,6 @@ class JUnitTestSuiteParserStrategy { companion object { fun parseJUnitTestSuite( rawText: String, - junitVersion: JUnitVersion, importPattern: Regex, packageName: String, testNamePattern: String, @@ -41,8 +40,9 @@ class JUnitTestSuiteParserStrategy { .map { it.groupValues[0] } .toMutableSet() - // save RunWith - val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: "" + // save ExtendWith or RunWith annotation if present + val runWithAnnotation: String = JUnitVersion.JUnit4.runWithAnnotationMeta.extract(rawCode) ?: "" + val annotation = JUnitVersion.JUnit5.runWithAnnotationMeta.extract(rawCode) ?: runWithAnnotation val testSet: MutableList = rawCode.split("@Test").toMutableList() @@ -82,7 +82,7 @@ class JUnitTestSuiteParserStrategy { TestSuiteGeneratedByLLM( imports = imports, packageName = packageName, - runWith = runWith, + annotation = annotation, otherInfo = otherInfo, testCases = testCases, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt index 9e894abe8..6205b688b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/java/JavaDisplayUtils.kt @@ -22,6 +22,7 @@ import org.jetbrains.research.testspark.display.utils.ErrorMessageManager import org.jetbrains.research.testspark.display.utils.template.DisplayUtils import org.jetbrains.research.testspark.java.JavaPsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.testmanager.java.JavaTestAnalyzer import org.jetbrains.research.testspark.testmanager.java.JavaTestGenerator import java.io.File @@ -141,8 +142,11 @@ class JavaDisplayUtils : DisplayUtils { psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile) psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0]) - if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { - psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext.testGenerationOutput.runWith})") + if (uiContext!!.testGenerationOutput.annotation.isNotEmpty()) { + val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + psiClass!!.modifierList!!.addAnnotation( + "${junitVersion.runWithAnnotationMeta.annotationName}(${uiContext.testGenerationOutput.annotation})", + ) } psiJavaFile!!.add(psiClass!!) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt index bb6a78c66..59c3567f7 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/display/utils/kotlin/KotlinDisplayUtils.kt @@ -24,6 +24,7 @@ import org.jetbrains.research.testspark.display.utils.ErrorMessageManager import org.jetbrains.research.testspark.display.utils.template.DisplayUtils import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.testmanager.kotlin.KotlinTestAnalyzer import org.jetbrains.research.testspark.testmanager.kotlin.KotlinTestGenerator import java.io.File @@ -144,9 +145,12 @@ class KotlinDisplayUtils : DisplayUtils { val ktPsiFactory = KtPsiFactory(project) ktClass = ktPsiFactory.createClass("class ${className.split(".")[0]} {}") - if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) { + if (uiContext!!.testGenerationOutput.annotation.isNotEmpty()) { + val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion val annotationEntry = - ktPsiFactory.createAnnotationEntry("@RunWith(${uiContext.testGenerationOutput.runWith})") + ktPsiFactory.createAnnotationEntry( + "@${junitVersion.runWithAnnotationMeta.annotationName}(${uiContext.testGenerationOutput.annotation})", + ) ktClass!!.addBefore(annotationEntry, ktClass!!.body) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/testmanager/java/JavaTestGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/testmanager/java/JavaTestGenerator.kt index 55ffa62c5..31beef72b 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/testmanager/java/JavaTestGenerator.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/testmanager/java/JavaTestGenerator.kt @@ -9,6 +9,7 @@ import com.intellij.psi.PsiFile import com.intellij.psi.PsiFileFactory import com.intellij.psi.codeStyle.CodeStyleManager import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.testmanager.template.TestGenerator import java.io.File @@ -21,11 +22,11 @@ object JavaTestGenerator : TestGenerator { body: String, imports: Set, packageString: String, - runWith: String, + annotation: String, otherInfo: String, testGenerationData: TestGenerationData, ): String { - var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo) + var testFullText = printUpperPart(className, imports, packageString, annotation, otherInfo, project) // Add each test (exclude expected exception) testFullText += body @@ -75,8 +76,9 @@ object JavaTestGenerator : TestGenerator { className: String, imports: Set, packageString: String, - runWith: String, + annotation: String, otherInfo: String, + project: Project, ): String { var testText = "" @@ -92,9 +94,10 @@ object JavaTestGenerator : TestGenerator { testText += "\n" - // add runWith if exists - if (runWith.isNotBlank()) { - testText += "@RunWith($runWith)\n" + // add RunWith or ExtendWith annotation if exists + if (annotation.isNotBlank()) { + val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + testText += "@${junitVersion.runWithAnnotationMeta.annotationName}($annotation)\n" } // open the test class testText += "public class $className {\n\n" diff --git a/src/main/kotlin/org/jetbrains/research/testspark/testmanager/kotlin/KotlinTestGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/testmanager/kotlin/KotlinTestGenerator.kt index 9900de250..b00dd24bb 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/testmanager/kotlin/KotlinTestGenerator.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/testmanager/kotlin/KotlinTestGenerator.kt @@ -9,6 +9,7 @@ import com.intellij.psi.PsiFileFactory import com.intellij.psi.codeStyle.CodeStyleManager import org.jetbrains.kotlin.idea.KotlinLanguage import org.jetbrains.research.testspark.core.data.TestGenerationData +import org.jetbrains.research.testspark.services.LLMSettingsService import org.jetbrains.research.testspark.testmanager.template.TestGenerator import java.io.File @@ -21,14 +22,14 @@ object KotlinTestGenerator : TestGenerator { body: String, imports: Set, packageString: String, - runWith: String, + annotation: String, otherInfo: String, testGenerationData: TestGenerationData, ): String { log.debug("[KotlinClassBuilderHelper] Generate code for $className") var testFullText = - printUpperPart(className, imports, packageString, runWith, otherInfo) + printUpperPart(className, imports, packageString, annotation, otherInfo, project) // Add each test (exclude expected exception) testFullText += body @@ -71,8 +72,9 @@ object KotlinTestGenerator : TestGenerator { className: String, imports: Set, packageString: String, - runWith: String, + annotation: String, otherInfo: String, + project: Project, ): String { var testText = "" @@ -88,9 +90,10 @@ object KotlinTestGenerator : TestGenerator { testText += "\n" - // Add runWith if exists - if (runWith.isNotBlank()) { - testText += "@RunWith($runWith::class)\n" + // Add ExtendWith or RunWith annotation if exists + if (annotation.isNotBlank()) { + val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion + testText += "@${junitVersion.runWithAnnotationMeta.annotationName}($annotation::class)\n" } // Open the test class diff --git a/src/main/kotlin/org/jetbrains/research/testspark/testmanager/template/TestGenerator.kt b/src/main/kotlin/org/jetbrains/research/testspark/testmanager/template/TestGenerator.kt index e92c1a51e..d5f4e99e7 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/testmanager/template/TestGenerator.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/testmanager/template/TestGenerator.kt @@ -15,7 +15,7 @@ interface TestGenerator { * @param body the body of the test class * @param imports the set of imports needed in the test class * @param packageString the package declaration of the test class - * @param runWith the runWith annotation for the test class + * @param annotation the RunWith or ExtendWith annotation for the test class * @param otherInfo any other additional information for the test class * @param testGenerationData the data used for test generation * @return the generated code as a string @@ -26,7 +26,7 @@ interface TestGenerator { body: String, imports: Set, packageString: String, - runWith: String, + annotation: String, otherInfo: String, testGenerationData: TestGenerationData, ): String diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt index 3015207fd..6f34ed91a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/ToolUtils.kt @@ -59,7 +59,7 @@ object ToolUtils { code, generatedTestData.importsCode, generatedTestData.packageName, - generatedTestData.runWith, + generatedTestData.annotation, generatedTestData.otherInfo, generatedTestData, ) diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt index 08d2c01aa..ae8ede48c 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/generation/JUnitTestsAssembler.kt @@ -52,11 +52,11 @@ class JUnitTestsAssembler( val testSuite = testSuiteParser.parseTestSuite(super.getContent()) // save RunWith - if (testSuite?.runWith?.isNotBlank() == true) { - generationData.runWith = testSuite.runWith + if (testSuite?.annotation?.isNotBlank() == true) { + generationData.annotation = testSuite.annotation generationData.importsCode.add(junitVersion.runWithAnnotationMeta.import) } else { - generationData.runWith = "" + generationData.annotation = "" generationData.importsCode.remove(junitVersion.runWithAnnotationMeta.import) } diff --git a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt index f45e2f172..8b3d68e4a 100644 --- a/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt +++ b/src/main/kotlin/org/jetbrains/research/testspark/tools/llm/test/JUnitTestSuitePresenter.kt @@ -42,7 +42,7 @@ class JUnitTestSuitePresenter( testBody, imports, packageName, - runWith, + annotation, otherInfo, generatedTestsData, ) @@ -65,7 +65,7 @@ class JUnitTestSuitePresenter( testCases[testCaseIndex].toStringWithoutExpectedException() + "\n", imports, packageName, - runWith, + annotation, otherInfo, generatedTestsData, ) @@ -89,7 +89,7 @@ class JUnitTestSuitePresenter( testBody, imports, packageName, - runWith, + annotation, otherInfo, generatedTestsData, )