Skip to content

Commit

Permalink
started WIP fir plugin, disabled for now
Browse files Browse the repository at this point in the history
  • Loading branch information
Jolanrensen committed Mar 24, 2024
1 parent 4c17859 commit 2f62d07
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 15 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package org.jetbrains.kotlinx.spark.api.compilerPlugin
import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar
import org.jetbrains.kotlin.config.CompilerConfiguration
import org.jetbrains.kotlin.fir.extensions.FirExtensionRegistrar
import org.jetbrains.kotlin.fir.extensions.FirExtensionRegistrarAdapter
import org.jetbrains.kotlinx.spark.api.Artifacts
import org.jetbrains.kotlinx.spark.api.compilerPlugin.ir.SparkifyIrGenerationExtension

open class SparkifyCompilerPluginRegistrar: CompilerPluginRegistrar() {
open class SparkifyCompilerPluginRegistrar : CompilerPluginRegistrar() {
init {
println("SparkifyCompilerPluginRegistrar loaded")
}
Expand All @@ -26,6 +28,15 @@ open class SparkifyCompilerPluginRegistrar: CompilerPluginRegistrar() {
val productFqNames = // TODO: get from configuration
listOf("scala.Product")

// Front end (FIR)
// FirExtensionRegistrarAdapter.registerExtension(
// SparkifyFirPluginRegistrar(
// sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
// productFqNames = productFqNames,
// )
// )

// Intermediate Representation IR
IrGenerationExtension.registerExtension(
SparkifyIrGenerationExtension(
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.jetbrains.kotlinx.spark.api.compilerPlugin

import org.jetbrains.kotlin.fir.extensions.FirExtensionRegistrar
import org.jetbrains.kotlinx.spark.api.compilerPlugin.fir.DataClassSparkifyFunctionsGenerator
import org.jetbrains.kotlinx.spark.api.compilerPlugin.fir.DataClassSparkifySuperTypeGenerator

// Potential future K2 FIR hook
// TODO
class SparkifyFirPluginRegistrar(
private val sparkifyAnnotationFqNames: List<String>,
private val productFqNames: List<String>
) : FirExtensionRegistrar() {
override fun ExtensionRegistrarContext.configurePlugin() {
+DataClassSparkifySuperTypeGenerator.builder(
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
productFqNames = productFqNames,
)
+DataClassSparkifyFunctionsGenerator.builder(
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
productFqNames = productFqNames,
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package org.jetbrains.kotlinx.spark.api.compilerPlugin.fir

import org.jetbrains.kotlin.GeneratedDeclarationKey
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.utils.isData
import org.jetbrains.kotlin.fir.extensions.FirDeclarationGenerationExtension
import org.jetbrains.kotlin.fir.extensions.MemberGenerationContext
import org.jetbrains.kotlin.fir.plugin.createMemberFunction
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.resolve.getSuperTypes
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
import org.jetbrains.kotlin.fir.types.toClassSymbol
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.Name

class DataClassSparkifyFunctionsGenerator(
session: FirSession,
private val sparkifyAnnotationFqNames: List<String>,
private val productFqNames: List<String>,
) : FirDeclarationGenerationExtension(session) {

companion object {
fun builder(
sparkifyAnnotationFqNames: List<String>,
productFqNames: List<String>
): (FirSession) -> FirDeclarationGenerationExtension = {
DataClassSparkifyFunctionsGenerator(
session = it,
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
productFqNames = productFqNames,
)
}

// functions to generate
val canEqual = Name.identifier("canEqual")
val productElement = Name.identifier("productElement")
val productArity = Name.identifier("productArity")
}

override fun generateFunctions(
callableId: CallableId,
context: MemberGenerationContext?
): List<FirNamedFunctionSymbol> {
val owner = context?.owner ?: return emptyList()

val functionName = callableId.callableName
val superTypes = owner.getSuperTypes(session)
val superProduct = superTypes.first {
it.toString().endsWith("Product")
}.toClassSymbol(session)!!
val superEquals = superTypes.first {
it.toString().endsWith("Equals")
}.toClassSymbol(session)!!

val function = when (functionName) {
canEqual -> {
val func = createMemberFunction(
owner = owner,
key = Key,
name = functionName,
returnType = session.builtinTypes.booleanType.type,
) {
valueParameter(
name = Name.identifier("that"),
type = session.builtinTypes.nullableAnyType.type,
)
}
// val superFunction = superEquals.declarationSymbols.first {
// it is FirNamedFunctionSymbol && it.name == functionName
// } as FirNamedFunctionSymbol
// overrides(func, superFunction)
func
}

productElement -> {
createMemberFunction(
owner = owner,
key = Key,
name = functionName,
returnType = session.builtinTypes.nullableAnyType.type,
) {
valueParameter(
name = Name.identifier("n"),
type = session.builtinTypes.intType.type,
)
}
}

productArity -> {
createMemberFunction(
owner = owner,
key = Key,
name = functionName,
returnType = session.builtinTypes.intType.type,
)
}

else -> {
return emptyList()
}
}

return listOf(function.symbol)
}

override fun getCallableNamesForClass(classSymbol: FirClassSymbol<*>, context: MemberGenerationContext): Set<Name> =
if (classSymbol.isData && classSymbol.annotations.any { "Sparkify" in it.render() }) {
setOf(canEqual, productElement, productArity)
} else {
emptySet()
}

object Key : GeneratedDeclarationKey()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package org.jetbrains.kotlinx.spark.api.compilerPlugin.fir

import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirClassLikeDeclaration
import org.jetbrains.kotlin.fir.declarations.utils.isData
import org.jetbrains.kotlin.fir.extensions.FirSupertypeGenerationExtension
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl
import org.jetbrains.kotlin.fir.types.FirResolvedTypeRef
import org.jetbrains.kotlin.fir.types.builder.buildResolvedTypeRef
import org.jetbrains.kotlin.fir.types.impl.ConeClassLikeTypeImpl
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName

/**
* This class tells the FIR that all @Sparkify annotated data classes
* get [scala.Product] as their super type.
*/
class DataClassSparkifySuperTypeGenerator(
session: FirSession,
private val sparkifyAnnotationFqNames: List<String>,
private val productFqNames: List<String>,
) : FirSupertypeGenerationExtension(session) {

companion object {
fun builder(sparkifyAnnotationFqNames: List<String>, productFqNames: List<String>): (FirSession) -> FirSupertypeGenerationExtension = {
DataClassSparkifySuperTypeGenerator(
session = it,
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
productFqNames = productFqNames,
)
}
}

context(TypeResolveServiceContainer)
override fun computeAdditionalSupertypes(
classLikeDeclaration: FirClassLikeDeclaration,
resolvedSupertypes: List<FirResolvedTypeRef>
): List<FirResolvedTypeRef> = listOf(
buildResolvedTypeRef {
val scalaProduct = productFqNames.first().let {
ClassId.topLevel(FqName(it))
}
type = ConeClassLikeTypeImpl(
lookupTag = ConeClassLikeLookupTagImpl(scalaProduct),
typeArguments = emptyArray(),
isNullable = false,
)
}

)

override fun needTransformSupertypes(declaration: FirClassLikeDeclaration): Boolean =
declaration.symbol.isData &&
declaration.annotations.any {
"Sparkify" in it.render()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.SpecialNames

class DataClassPropertyAnnotationGenerator(
class DataClassSparkifyGenerator(
private val pluginContext: IrPluginContext,
private val sparkifyAnnotationFqNames: List<String>,
private val columnNameAnnotationFqNames: List<String>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
import org.jetbrains.kotlinx.spark.api.compilerPlugin.ir.DataClassPropertyAnnotationGenerator

class SparkifyIrGenerationExtension(
private val sparkifyAnnotationFqNames: List<String>,
Expand All @@ -13,7 +12,7 @@ class SparkifyIrGenerationExtension(
) : IrGenerationExtension {
override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
val visitors = listOf(
DataClassPropertyAnnotationGenerator(
DataClassSparkifyGenerator(
pluginContext = pluginContext,
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
columnNameAnnotationFqNames = columnNameAnnotationFqNames,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@


package org.jetbrains.kotlinx.spark.api.compilerPlugin.runners;

import com.intellij.testFramework.TestDataPath;
import org.jetbrains.kotlin.test.util.KtTestUtil;
import org.jetbrains.kotlin.test.TestMetadata;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.util.regex.Pattern;

/** This class is generated by {@link org.jetbrains.kotlinx.spark.api.compilerPlugin.GenerateTestsKt}. DO NOT MODIFY MANUALLY */
@SuppressWarnings("all")
@TestMetadata("/mnt/data/Projects/kotlin-spark-api/compiler-plugin/src/test/resources/testData/diagnostics")
@TestDataPath("$PROJECT_ROOT")
public class DiagnosticTestGenerated extends AbstractDiagnosticTest {
@Test
public void testAllFilesPresentInDiagnostics() {
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("/mnt/data/Projects/kotlin-spark-api/compiler-plugin/src/test/resources/testData/diagnostics"), Pattern.compile("^(.+)\\.kt$"), null, true);
}

@Test
@TestMetadata("dataClassTest.kt")
public void testDataClassTest() {
runTest("/mnt/data/Projects/kotlin-spark-api/compiler-plugin/src/test/resources/testData/diagnostics/dataClassTest.kt");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ package org.jetbrains.kotlinx.spark.api.compilerPlugin
import org.jetbrains.kotlin.generators.generateTestGroupSuiteWithJUnit5
import org.jetbrains.kotlinx.spark.api.Artifacts
import org.jetbrains.kotlinx.spark.api.compilerPlugin.runners.AbstractBoxTest
import org.jetbrains.kotlinx.spark.api.compilerPlugin.runners.AbstractDiagnosticTest

fun main() {
generateTestGroupSuiteWithJUnit5 {
testGroup(
testDataRoot = "${Artifacts.projectRoot}/${Artifacts.compilerPluginArtifactId}/src/test/resources/testData",
testsRoot = "${Artifacts.projectRoot}/${Artifacts.compilerPluginArtifactId}/src/test-gen/kotlin",
) {
// testClass<AbstractDiagnosticTest> {
// model("diagnostics")
// }
testClass<AbstractDiagnosticTest> {
model("diagnostics")
}

testClass<AbstractBoxTest> {
model("box")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package org.jetbrains.kotlinx.spark.api.compilerPlugin.services
import org.jetbrains.kotlin.backend.common.extensions.IrGenerationExtension
import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar
import org.jetbrains.kotlin.config.CompilerConfiguration
import org.jetbrains.kotlin.fir.extensions.FirExtensionRegistrarAdapter
import org.jetbrains.kotlin.test.model.TestModule
import org.jetbrains.kotlin.test.services.EnvironmentConfigurator
import org.jetbrains.kotlin.test.services.TestServices
import org.jetbrains.kotlinx.spark.api.compilerPlugin.SparkifyFirPluginRegistrar
import org.jetbrains.kotlinx.spark.api.compilerPlugin.ir.SparkifyIrGenerationExtension

class ExtensionRegistrarConfigurator(testServices: TestServices) : EnvironmentConfigurator(testServices) {
Expand All @@ -16,6 +18,16 @@ class ExtensionRegistrarConfigurator(testServices: TestServices) : EnvironmentCo
val sparkifyAnnotationFqNames = listOf("foo.bar.Sparkify")
val columnNameAnnotationFqNames = listOf("foo.bar.ColumnName")
val productFqNames = listOf("foo.bar.Product")

// Front end (FIR)
// FirExtensionRegistrarAdapter.registerExtension(
// SparkifyFirPluginRegistrar(
// sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
// productFqNames = productFqNames,
// )
// )

// Intermediate Representation IR
IrGenerationExtension.registerExtension(
SparkifyIrGenerationExtension(
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package foo.bar

annotation class Sparkify
annotation class ColumnName(val name: String)

// Fake Equals
interface Equals {
fun canEqual(that: Any?): Boolean
}

// Fake Product
interface Product: Equals {
fun productElement(n: Int): Any
fun productArity(): Int
}

fun test() {
val user = User()
user.productArity() // should not be an error
}

@Sparkify
data <!ABSTRACT_MEMBER_NOT_IMPLEMENTED!>class User<!>(
val name: String = "John Doe",
val age: Int = 25,
@ColumnName("a") val test: Double = 1.0,
@get:ColumnName("b") val test2: Double = 2.0,
)

0 comments on commit 2f62d07

Please sign in to comment.