Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ data class TestGenerationData(
// Code required of imports and package for generated tests
var importsCode: MutableSet<String> = mutableSetOf(),
var packageName: String = "",
var runWith: String = "",
var annotation: String = "",
// Modifications to this code in the tool-window editor are forgotten when apply to test suite
var otherInfo: String = "",
// changing parameters with a large prompt
Expand All @@ -30,7 +30,7 @@ data class TestGenerationData(
fileUrl = ""
importsCode = mutableSetOf()
packageName = ""
runWith = ""
annotation = ""
otherInfo = ""
polyDepthReducing = 0
inputParamsDepthReducing = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ package org.jetbrains.research.testspark.core.test.data
data class TestSuiteGeneratedByLLM(
var imports: MutableSet<String> = mutableSetOf(),
var packageName: String = "",
var runWith: String = "",
var annotation: String = "",
var otherInfo: String = "",
var testCases: MutableList<TestCaseGeneratedByLLM> = mutableListOf(),
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class JavaJUnitTestSuiteParser(

return JUnitTestSuiteParserStrategy.parseJUnitTestSuite(
rawText,
junitVersion,
javaImportPattern,
packageName,
testNamePattern = "void",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class KotlinJUnitTestSuiteParser(

return JUnitTestSuiteParserStrategy.parseJUnitTestSuite(
rawText,
junitVersion,
kotlinImportPattern,
packageName,
testNamePattern = "fun",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class JUnitTestSuiteParserStrategy {
companion object {
fun parseJUnitTestSuite(
rawText: String,
junitVersion: JUnitVersion,
importPattern: Regex,
packageName: String,
testNamePattern: String,
Expand All @@ -41,8 +40,9 @@ class JUnitTestSuiteParserStrategy {
.map { it.groupValues[0] }
.toMutableSet()

// save RunWith
val runWith: String = junitVersion.runWithAnnotationMeta.extract(rawCode) ?: ""
// save ExtendWith or RunWith annotation if present
val runWithAnnotation: String = JUnitVersion.JUnit4.runWithAnnotationMeta.extract(rawCode) ?: ""
val annotation = JUnitVersion.JUnit5.runWithAnnotationMeta.extract(rawCode) ?: runWithAnnotation

val testSet: MutableList<String> = rawCode.split("@Test").toMutableList()

Expand Down Expand Up @@ -82,7 +82,7 @@ class JUnitTestSuiteParserStrategy {
TestSuiteGeneratedByLLM(
imports = imports,
packageName = packageName,
runWith = runWith,
annotation = annotation,
otherInfo = otherInfo,
testCases = testCases,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.jetbrains.research.testspark.display.utils.ErrorMessageManager
import org.jetbrains.research.testspark.display.utils.template.DisplayUtils
import org.jetbrains.research.testspark.java.JavaPsiClassWrapper
import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
import org.jetbrains.research.testspark.services.LLMSettingsService
import org.jetbrains.research.testspark.testmanager.java.JavaTestAnalyzer
import org.jetbrains.research.testspark.testmanager.java.JavaTestGenerator
import java.io.File
Expand Down Expand Up @@ -141,8 +142,11 @@ class JavaDisplayUtils : DisplayUtils {
psiJavaFile = (PsiManager.getInstance(project).findFile(virtualFile!!) as PsiJavaFile)
psiClass = PsiElementFactory.getInstance(project).createClass(className.split(".")[0])

if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) {
psiClass!!.modifierList!!.addAnnotation("RunWith(${uiContext.testGenerationOutput.runWith})")
if (uiContext!!.testGenerationOutput.annotation.isNotEmpty()) {
val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion
psiClass!!.modifierList!!.addAnnotation(
"${junitVersion.runWithAnnotationMeta.annotationName}(${uiContext.testGenerationOutput.annotation})",
)
}

psiJavaFile!!.add(psiClass!!)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.jetbrains.research.testspark.display.utils.ErrorMessageManager
import org.jetbrains.research.testspark.display.utils.template.DisplayUtils
import org.jetbrains.research.testspark.kotlin.KotlinPsiClassWrapper
import org.jetbrains.research.testspark.langwrappers.PsiClassWrapper
import org.jetbrains.research.testspark.services.LLMSettingsService
import org.jetbrains.research.testspark.testmanager.kotlin.KotlinTestAnalyzer
import org.jetbrains.research.testspark.testmanager.kotlin.KotlinTestGenerator
import java.io.File
Expand Down Expand Up @@ -144,9 +145,12 @@ class KotlinDisplayUtils : DisplayUtils {
val ktPsiFactory = KtPsiFactory(project)
ktClass = ktPsiFactory.createClass("class ${className.split(".")[0]} {}")

if (uiContext!!.testGenerationOutput.runWith.isNotEmpty()) {
if (uiContext!!.testGenerationOutput.annotation.isNotEmpty()) {
val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion
val annotationEntry =
ktPsiFactory.createAnnotationEntry("@RunWith(${uiContext.testGenerationOutput.runWith})")
ktPsiFactory.createAnnotationEntry(
"@${junitVersion.runWithAnnotationMeta.annotationName}(${uiContext.testGenerationOutput.annotation})",
)
ktClass!!.addBefore(annotationEntry, ktClass!!.body)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.intellij.psi.PsiFile
import com.intellij.psi.PsiFileFactory
import com.intellij.psi.codeStyle.CodeStyleManager
import org.jetbrains.research.testspark.core.data.TestGenerationData
import org.jetbrains.research.testspark.services.LLMSettingsService
import org.jetbrains.research.testspark.testmanager.template.TestGenerator
import java.io.File

Expand All @@ -21,11 +22,11 @@ object JavaTestGenerator : TestGenerator {
body: String,
imports: Set<String>,
packageString: String,
runWith: String,
annotation: String,
otherInfo: String,
testGenerationData: TestGenerationData,
): String {
var testFullText = printUpperPart(className, imports, packageString, runWith, otherInfo)
var testFullText = printUpperPart(className, imports, packageString, annotation, otherInfo, project)

// Add each test (exclude expected exception)
testFullText += body
Expand Down Expand Up @@ -75,8 +76,9 @@ object JavaTestGenerator : TestGenerator {
className: String,
imports: Set<String>,
packageString: String,
runWith: String,
annotation: String,
otherInfo: String,
project: Project,
): String {
var testText = ""

Expand All @@ -92,9 +94,10 @@ object JavaTestGenerator : TestGenerator {

testText += "\n"

// add runWith if exists
if (runWith.isNotBlank()) {
testText += "@RunWith($runWith)\n"
// add RunWith or ExtendWith annotation if exists
if (annotation.isNotBlank()) {
val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion
testText += "@${junitVersion.runWithAnnotationMeta.annotationName}($annotation)\n"
}
// open the test class
testText += "public class $className {\n\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.intellij.psi.PsiFileFactory
import com.intellij.psi.codeStyle.CodeStyleManager
import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.research.testspark.core.data.TestGenerationData
import org.jetbrains.research.testspark.services.LLMSettingsService
import org.jetbrains.research.testspark.testmanager.template.TestGenerator
import java.io.File

Expand All @@ -21,14 +22,14 @@ object KotlinTestGenerator : TestGenerator {
body: String,
imports: Set<String>,
packageString: String,
runWith: String,
annotation: String,
otherInfo: String,
testGenerationData: TestGenerationData,
): String {
log.debug("[KotlinClassBuilderHelper] Generate code for $className")

var testFullText =
printUpperPart(className, imports, packageString, runWith, otherInfo)
printUpperPart(className, imports, packageString, annotation, otherInfo, project)

// Add each test (exclude expected exception)
testFullText += body
Expand Down Expand Up @@ -71,8 +72,9 @@ object KotlinTestGenerator : TestGenerator {
className: String,
imports: Set<String>,
packageString: String,
runWith: String,
annotation: String,
otherInfo: String,
project: Project,
): String {
var testText = ""

Expand All @@ -88,9 +90,10 @@ object KotlinTestGenerator : TestGenerator {

testText += "\n"

// Add runWith if exists
if (runWith.isNotBlank()) {
testText += "@RunWith($runWith::class)\n"
// Add ExtendWith or RunWith annotation if exists
if (annotation.isNotBlank()) {
val junitVersion = project.getService(LLMSettingsService::class.java).state.junitVersion
testText += "@${junitVersion.runWithAnnotationMeta.annotationName}($annotation::class)\n"
}

// Open the test class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ interface TestGenerator {
* @param body the body of the test class
* @param imports the set of imports needed in the test class
* @param packageString the package declaration of the test class
* @param runWith the runWith annotation for the test class
* @param annotation the RunWith or ExtendWith annotation for the test class
* @param otherInfo any other additional information for the test class
* @param testGenerationData the data used for test generation
* @return the generated code as a string
Expand All @@ -26,7 +26,7 @@ interface TestGenerator {
body: String,
imports: Set<String>,
packageString: String,
runWith: String,
annotation: String,
otherInfo: String,
testGenerationData: TestGenerationData,
): String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ object ToolUtils {
code,
generatedTestData.importsCode,
generatedTestData.packageName,
generatedTestData.runWith,
generatedTestData.annotation,
generatedTestData.otherInfo,
generatedTestData,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ class JUnitTestsAssembler(
val testSuite = testSuiteParser.parseTestSuite(super.getContent())

// save RunWith
if (testSuite?.runWith?.isNotBlank() == true) {
generationData.runWith = testSuite.runWith
if (testSuite?.annotation?.isNotBlank() == true) {
generationData.annotation = testSuite.annotation
generationData.importsCode.add(junitVersion.runWithAnnotationMeta.import)
} else {
generationData.runWith = ""
generationData.annotation = ""
generationData.importsCode.remove(junitVersion.runWithAnnotationMeta.import)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class JUnitTestSuitePresenter(
testBody,
imports,
packageName,
runWith,
annotation,
otherInfo,
generatedTestsData,
)
Expand All @@ -65,7 +65,7 @@ class JUnitTestSuitePresenter(
testCases[testCaseIndex].toStringWithoutExpectedException() + "\n",
imports,
packageName,
runWith,
annotation,
otherInfo,
generatedTestsData,
)
Expand All @@ -89,7 +89,7 @@ class JUnitTestSuitePresenter(
testBody,
imports,
packageName,
runWith,
annotation,
otherInfo,
generatedTestsData,
)
Expand Down