From b3f70d07786cbd95b8a891dbffaaeddb732ebfd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 10:15:50 +0200 Subject: [PATCH 01/11] Allocate reflective proxy IDs from the Preprocessor. We can allocate the IDs without looking at the body of methods, so this allows to have a fixed assignment from the start. --- .../backend/wasmemitter/Preprocessor.scala | 16 ++++++++++++++- .../backend/wasmemitter/WasmContext.scala | 20 +++++++++---------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index b759925..66d009b 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -1,5 +1,7 @@ package org.scalajs.linker.backend.wasmemitter +import scala.collection.mutable + import org.scalajs.ir.Names._ import org.scalajs.ir.Trees._ import org.scalajs.ir.Types._ @@ -16,9 +18,21 @@ object Preprocessor { ): Unit = { val staticFieldMirrors = computeStaticFieldMirrors(tles) - for (clazz <- classes) + val definedReflectiveProxyNames = mutable.HashSet.empty[MethodName] + + for (clazz <- classes) { preprocess(clazz, staticFieldMirrors.getOrElse(clazz.className, Map.empty)) + // For Scala classes, collect the reflective proxy method names that it defines + if (clazz.kind.isClass || clazz.kind == ClassKind.HijackedClass) { + for (method <- clazz.methods if method.methodName.isReflectiveProxy) + definedReflectiveProxyNames += method.methodName + } + } + + // sort for stability + ctx.setReflectiveProxyIDs(definedReflectiveProxyNames.toList.sorted.zipWithIndex.toMap) + val collector = new AbstractMethodCallCollector(ctx) for (clazz <- classes) collector.collectAbstractMethodCalls(clazz) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 35bfd08..a6a0fb8 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -29,6 +29,7 @@ final class WasmContext { import WasmContext._ private val classInfo = mutable.Map[ClassName, ClassInfo]() + private var reflectiveProxies: Map[MethodName, Int] = null private var _itablesLength: Int = 0 def itablesLength = _itablesLength @@ -38,7 +39,6 @@ final class WasmContext { private val constantStringGlobals = LinkedHashMap.empty[String, StringData] private val classItableGlobals = mutable.ListBuffer.empty[ClassName] private val closureDataTypes = LinkedHashMap.empty[List[Type], wanme.TypeID] - private val reflectiveProxies = LinkedHashMap.empty[MethodName, Int] val moduleBuilder: ModuleBuilder = { new ModuleBuilder(new ModuleBuilder.FunctionTypeProvider { @@ -59,7 +59,6 @@ final class WasmContext { private var nextConstatnStringOffset: Int = 0 private var nextArrayTypeIndex: Int = 1 private var nextClosureDataTypeIndex: Int = 1 - private var nextReflectiveProxyIdx: Int = 0 private val _importedModules: mutable.LinkedHashSet[String] = new mutable.LinkedHashSet() @@ -99,15 +98,16 @@ final class WasmContext { ArrayType(typeRef) } - /** Retrieves a unique identifier for a reflective proxy with the given name */ + /** Sets the map of reflexity proxy IDs, only for use by `Preprocessor`. */ + def setReflectiveProxyIDs(proxyIDs: Map[MethodName, Int]): Unit = + reflectiveProxies = proxyIDs + + /** Retrieves a unique identifier for a reflective proxy with the given name. + * + * If no class defines a reflective proxy with the given name, returns `-1`. + */ def getReflectiveProxyId(name: MethodName): Int = - reflectiveProxies.getOrElseUpdate( - name, { - val idx = nextReflectiveProxyIdx - nextReflectiveProxyIdx += 1 - idx - } - ) + reflectiveProxies.getOrElse(name, -1) /** Adds or reuses a function type for a table function. * From eb0fb7163626270f4d96179c6cdb1ad58f4dd2e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:41:11 +0200 Subject: [PATCH 02/11] Remove dead code variable `WasmContext.nextArrayTypeIndex`. --- .../org/scalajs/linker/backend/wasmemitter/WasmContext.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index a6a0fb8..09696a8 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -57,7 +57,6 @@ final class WasmContext { private var stringPool = new mutable.ArrayBuffer[Byte]() private var nextConstantStringIndex: Int = 0 private var nextConstatnStringOffset: Int = 0 - private var nextArrayTypeIndex: Int = 1 private var nextClosureDataTypeIndex: Int = 1 private val _importedModules: mutable.LinkedHashSet[String] = From d71e15f887becc7fba0605876be39e4e25d307d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:42:55 +0200 Subject: [PATCH 03/11] Do not explicitly store the next string offset. We can use `stringPool.size` instead. --- .../org/scalajs/linker/backend/wasmemitter/WasmContext.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 09696a8..08f352b 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -56,7 +56,6 @@ final class WasmContext { private var stringPool = new mutable.ArrayBuffer[Byte]() private var nextConstantStringIndex: Int = 0 - private var nextConstatnStringOffset: Int = 0 private var nextClosureDataTypeIndex: Int = 1 private val _importedModules: mutable.LinkedHashSet[String] = @@ -151,13 +150,12 @@ final class WasmContext { val bytes = str.toCharArray.flatMap { char => Array((char & 0xFF).toByte, (char >> 8).toByte) } - val offset = nextConstatnStringOffset + val offset = stringPool.size val data = StringData(nextConstantStringIndex, offset) constantStringGlobals(str) = data stringPool ++= bytes nextConstantStringIndex += 1 - nextConstatnStringOffset += bytes.length data } } From 361a32f685e5d00a33bde8192b790e9cae0fa9d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 10:26:51 +0200 Subject: [PATCH 04/11] Move `assignBuckets` to `Preprocessor`. --- .../backend/wasmemitter/Preprocessor.scala | 149 ++++++++++++++++- .../backend/wasmemitter/WasmContext.scala | 157 +----------------- 2 files changed, 157 insertions(+), 149 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 66d009b..d4e0d9a 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -42,7 +42,7 @@ object Preprocessor { for (clazz <- classes) { ctx.getClassInfo(clazz.className).buildMethodTable() } - ctx.assignBuckets(classes) + assignBuckets(classes) } private def computeStaticFieldMirrors( @@ -226,4 +226,151 @@ object Preprocessor { } } } + + /** Group interface types + types that implements any interfaces into buckets, where no two types + * in the same bucket can have common subtypes. + * + * It allows compressing the itable by reusing itable's index (buckets) for unrelated types, + * instead of having a 1-1 mapping from type to index. As a result, the itables' length will be + * the same as the number of buckets). + * + * The algorithm separates the type hierarchy into three disjoint subsets, + * + * - join types: types with multiple parents (direct supertypes) that have only single + * subtyping descendants: `join(T) = {x ∈ multis(T) | ∄ y ∈ multis(T) : y <: x}` where + * multis(T) means types with multiple direct supertypes. + * - spine types: all ancestors of join types: `spine(T) = {x ∈ T | ∃ y ∈ join(T) : x ∈ + * ancestors(y)}` + * - plain types: types that are neither join nor spine types + * + * The bucket assignment process consists of two parts: + * + * **1. Assign buckets to spine types** + * + * Two spine types can share the same bucket only if they do not have any common join type + * descendants. + * + * Visit spine types in reverse topological order because (from leaves to root) when assigning a + * a spine type to bucket, the algorithm already has the complete information about the + * join/spine type descendants of that spine type. + * + * Assign a bucket to a spine type if adding it doesn't violate the bucket assignment rule: two + * spine types can share a bucket only if they don't have any common join type descendants. If no + * existing bucket satisfies the rule, create a new bucket. + * + * **2. Assign buckets to non-spine types (plain and join types)** + * + * Visit these types in level order (from root to leaves) For each type, compute the set of + * buckets already used by its ancestors. Assign the type to any available bucket not in this + * set. If no available bucket exists, create a new one. + * + * To test if type A is a subtype of type B: load the bucket index of type B (we do this by + * `getItableIdx`), load the itable at that index from A, and check if the itable is an itable + * for B. + * + * @see + * This algorithm is based on the "packed encoding" presented in the paper "Efficient Type + * Inclusion Tests" + * [[https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests]] + */ + private def assignBuckets(allClasses: List[LinkedClass])(implicit ctx: WasmContext): Unit = { + val classes = allClasses.filterNot(_.kind.isJSType) + + var nextIdx = 0 + def newBucket(): Bucket = { + val idx = nextIdx + nextIdx += 1 + new Bucket(idx) + } + def getAllInterfaces(info: ClassInfo): List[ClassName] = + info.ancestors.filter(ctx.getClassInfo(_).isInterface) + + val buckets = new mutable.ListBuffer[Bucket]() + + /** All join type descendants of the class */ + val joinsOf = + new mutable.HashMap[ClassName, mutable.HashSet[ClassName]]() + + /** the buckets that have been assigned to any of the ancestors of the class */ + val usedOf = new mutable.HashMap[ClassName, mutable.HashSet[Bucket]]() + val spines = new mutable.HashSet[ClassName]() + + for (clazz <- classes.reverseIterator) { + val info = ctx.getClassInfo(clazz.name.name) + val ifaces = getAllInterfaces(info) + if (ifaces.nonEmpty) { + val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet()) + + if (joins.nonEmpty) { // spine type + var found = false + val bs = buckets.iterator + // look for an existing bucket to add the spine type to + while (!found && bs.hasNext) { + val b = bs.next() + // two spine types can share a bucket only if they don't have any common join type descendants + if (!b.joins.exists(joins)) { + found = true + b.add(info) + b.joins ++= joins + } + } + if (!found) { // there's no bucket to add, create new bucket + val b = newBucket() + b.add(info) + buckets.append(b) + b.joins ++= joins + } + for (iface <- ifaces) { + joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) ++= joins + } + spines.add(clazz.name.name) + } else if (ifaces.length > 1) { // join type, add to joins map, bucket assignment is done later + ifaces.foreach { iface => + joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) += clazz.name.name + } + } + // else: plain, do nothing + } + + } + + for (clazz <- classes) { + val info = ctx.getClassInfo(clazz.name.name) + val ifaces = getAllInterfaces(info) + if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) { + val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet()) + for { + iface <- ifaces + parentUsed <- usedOf.get(iface) + } { used ++= parentUsed } + + var found = false + val bs = buckets.iterator + while (!found && bs.hasNext) { + val b = bs.next() + if (!used.contains(b)) { + found = true + b.add(info) + used.add(b) + } + } + if (!found) { + val b = newBucket() + buckets.append(b) + b.add(info) + used.add(b) + } + } + } + + ctx.setItablesLength(buckets.length) + } + + private final class Bucket(idx: Int) { + def add(clazz: ClassInfo) = clazz.setItableIdx((idx)) + + /** A set of join types that are descendants of the types assigned to that bucket */ + val joins = new mutable.HashSet[ClassName]() + } + } diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 08f352b..29f6c48 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -30,9 +30,16 @@ final class WasmContext { private val classInfo = mutable.Map[ClassName, ClassInfo]() private var reflectiveProxies: Map[MethodName, Int] = null + private var _itablesLength: Int = -1 - private var _itablesLength: Int = 0 - def itablesLength = _itablesLength + /** Sets the length of the itables arrays, only for use by `Preprocessor`. */ + def setItablesLength(length: Int): Unit = + _itablesLength = length + + def itablesLength: Int = { + require(_itablesLength != -1, s"itablesLength was not yet assigned") + _itablesLength + } private val functionTypes = LinkedHashMap.empty[watpe.FunctionType, wanme.TypeID] private val tableFunctionTypes = mutable.HashMap.empty[MethodName, wanme.TypeID] @@ -200,9 +207,6 @@ final class WasmContext { wa.RefFunc(name) } - def assignBuckets(classes: List[LinkedClass]): Unit = - _itablesLength = assignBuckets0(classes.filterNot(_.kind.isJSType)) - def addExport(exprt: wamod.Export): Unit = moduleBuilder.addExport(exprt) @@ -254,142 +258,6 @@ final class WasmContext { def getAllFuncDeclarations(): List[wanme.FunctionID] = _funcDeclarations.toList - - /** Group interface types + types that implements any interfaces into buckets, where no two types - * in the same bucket can have common subtypes. - * - * It allows compressing the itable by reusing itable's index (buckets) for unrelated types, - * instead of having a 1-1 mapping from type to index. As a result, the itables' length will be - * the same as the number of buckets). - * - * The algorithm separates the type hierarchy into three disjoint subsets, - * - * - join types: types with multiple parents (direct supertypes) that have only single - * subtyping descendants: `join(T) = {x ∈ multis(T) | ∄ y ∈ multis(T) : y <: x}` where - * multis(T) means types with multiple direct supertypes. - * - spine types: all ancestors of join types: `spine(T) = {x ∈ T | ∃ y ∈ join(T) : x ∈ - * ancestors(y)}` - * - plain types: types that are neither join nor spine types - * - * The bucket assignment process consists of two parts: - * - * **1. Assign buckets to spine types** - * - * Two spine types can share the same bucket only if they do not have any common join type - * descendants. - * - * Visit spine types in reverse topological order because (from leaves to root) when assigning a - * a spine type to bucket, the algorithm already has the complete information about the - * join/spine type descendants of that spine type. - * - * Assign a bucket to a spine type if adding it doesn't violate the bucket assignment rule: two - * spine types can share a bucket only if they don't have any common join type descendants. If no - * existing bucket satisfies the rule, create a new bucket. - * - * **2. Assign buckets to non-spine types (plain and join types)** - * - * Visit these types in level order (from root to leaves) For each type, compute the set of - * buckets already used by its ancestors. Assign the type to any available bucket not in this - * set. If no available bucket exists, create a new one. - * - * To test if type A is a subtype of type B: load the bucket index of type B (we do this by - * `getItableIdx`), load the itable at that index from A, and check if the itable is an itable - * for B. - * - * @see - * This algorithm is based on the "packed encoding" presented in the paper "Efficient Type - * Inclusion Tests" - * [[https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests]] - */ - private def assignBuckets0(classes: List[LinkedClass]): Int = { - var nextIdx = 0 - def newBucket(): Bucket = { - val idx = nextIdx - nextIdx += 1 - new Bucket(idx) - } - def getAllInterfaces(info: ClassInfo): List[ClassName] = - info.ancestors.filter(getClassInfo(_).isInterface) - - val buckets = new mutable.ListBuffer[Bucket]() - - /** All join type descendants of the class */ - val joinsOf = - new mutable.HashMap[ClassName, mutable.HashSet[ClassName]]() - - /** the buckets that have been assigned to any of the ancestors of the class */ - val usedOf = new mutable.HashMap[ClassName, mutable.HashSet[Bucket]]() - val spines = new mutable.HashSet[ClassName]() - - for (clazz <- classes.reverseIterator) { - val info = getClassInfo(clazz.name.name) - val ifaces = getAllInterfaces(info) - if (ifaces.nonEmpty) { - val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet()) - - if (joins.nonEmpty) { // spine type - var found = false - val bs = buckets.iterator - // look for an existing bucket to add the spine type to - while (!found && bs.hasNext) { - val b = bs.next() - // two spine types can share a bucket only if they don't have any common join type descendants - if (!b.joins.exists(joins)) { - found = true - b.add(info) - b.joins ++= joins - } - } - if (!found) { // there's no bucket to add, create new bucket - val b = newBucket() - b.add(info) - buckets.append(b) - b.joins ++= joins - } - for (iface <- ifaces) { - joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) ++= joins - } - spines.add(clazz.name.name) - } else if (ifaces.length > 1) { // join type, add to joins map, bucket assignment is done later - ifaces.foreach { iface => - joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) += clazz.name.name - } - } - // else: plain, do nothing - } - - } - - for (clazz <- classes) { - val info = getClassInfo(clazz.name.name) - val ifaces = getAllInterfaces(info) - if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) { - val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet()) - for { - iface <- ifaces - parentUsed <- usedOf.get(iface) - } { used ++= parentUsed } - - var found = false - val bs = buckets.iterator - while (!found && bs.hasNext) { - val b = bs.next() - if (!used.contains(b)) { - found = true - b.add(info) - used.add(b) - } - } - if (!found) { - val b = newBucket() - buckets.append(b) - b.add(info) - used.add(b) - } - } - } - buckets.length - } } object WasmContext { @@ -547,11 +415,4 @@ object WasmContext { def isEffectivelyFinal: Boolean = effectivelyFinal } - - private[WasmContext] class Bucket(idx: Int) { - def add(clazz: ClassInfo) = clazz.setItableIdx((idx)) - - /** A set of join types that are descendants of the types assigned to that bucket */ - val joins = new mutable.HashSet[ClassName]() - } } From cb31f25237bd1c5d05c3036d0f952ed16d7237da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 10:54:22 +0200 Subject: [PATCH 05/11] Do not store the list of classes with itable globals in `WasmContext`. We can compute that in `Emitter.genStartFunction` instead with no additional cost. --- .../backend/wasmemitter/ClassEmitter.scala | 10 ++--- .../linker/backend/wasmemitter/Emitter.scala | 41 ++++++++++--------- .../backend/wasmemitter/WasmContext.scala | 15 +++---- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala index 9d53f9b..76bc66a 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala @@ -510,8 +510,7 @@ class ClassEmitter(coreSpec: CoreSpec) { instrs += wa.GlobalGet(genGlobalID.forVTable(className)) - val interfaces = classInfo.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface) - if (!interfaces.isEmpty) + if (classInfo.classImplementsAnyInterface) instrs += wa.GlobalGet(genGlobalID.forITable(className)) else instrs += wa.RefNull(watpe.HeapType(genTypeID.itables)) @@ -618,9 +617,8 @@ class ClassEmitter(coreSpec: CoreSpec) { */ private def genGlobalClassItable(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { val className = clazz.className - val info = ctx.getClassInfo(className) - val implementsAnyInterface = info.ancestors.exists(a => ctx.getClassInfo(a).isInterface) - if (implementsAnyInterface) { + + if (ctx.getClassInfo(className).classImplementsAnyInterface) { val globalName = genGlobalID.forITable(className) val itablesInit = List( wa.I32Const(ctx.itablesLength), @@ -633,7 +631,7 @@ class ClassEmitter(coreSpec: CoreSpec) { wa.Expr(itablesInit), isMutable = false ) - ctx.addGlobalITable(className, global) + ctx.addGlobal(global) } } diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala index 23a66ea..fc5abbe 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala @@ -54,12 +54,9 @@ final class Emitter(config: Emitter.Config) { } CoreWasmLib.genPostClasses() - val classesWithStaticInit = - sortedClasses.filter(_.hasStaticInitializer).map(_.className) - complete( + sortedClasses, module.initializers.toList, - classesWithStaticInit, module.topLevelExports ) @@ -73,8 +70,8 @@ final class Emitter(config: Emitter.Config) { } private def complete( + sortedClasses: List[LinkedClass], moduleInitializers: List[ModuleInitializer.Initializer], - classesWithStaticInit: List[ClassName], topLevelExportDefs: List[LinkedTopLevelExport] )(implicit ctx: WasmContext): Unit = { /* Before generating the string globals in `genStartFunction()`, make sure @@ -114,13 +111,13 @@ final class Emitter(config: Emitter.Config) { ) ) - genStartFunction(moduleInitializers, classesWithStaticInit, topLevelExportDefs) + genStartFunction(sortedClasses, moduleInitializers, topLevelExportDefs) genDeclarativeElements() } private def genStartFunction( + sortedClasses: List[LinkedClass], moduleInitializers: List[ModuleInitializer.Initializer], - classesWithStaticInit: List[ClassName], topLevelExportDefs: List[LinkedTopLevelExport] )(implicit ctx: WasmContext): Unit = { import org.scalajs.ir.Trees._ @@ -132,20 +129,24 @@ final class Emitter(config: Emitter.Config) { val instrs: fb.type = fb // Initialize itables - for (className <- ctx.getAllClassesWithITableGlobal()) { + for (clazz <- sortedClasses if clazz.kind.isClass && clazz.hasDirectInstances) { + val className = clazz.className val classInfo = ctx.getClassInfo(className) - val interfaces = classInfo.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface) - val resolvedMethodInfos = classInfo.resolvedMethodInfos - interfaces.foreach { iface => - val idx = ctx.getItableIdx(iface) - instrs += wa.GlobalGet(genGlobalID.forITable(className)) - instrs += wa.I32Const(idx) + if (classInfo.classImplementsAnyInterface) { + val interfaces = classInfo.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface) + val resolvedMethodInfos = classInfo.resolvedMethodInfos - for (method <- iface.tableEntries) - instrs += ctx.refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName) - instrs += wa.StructNew(genTypeID.forITable(iface.name)) - instrs += wa.ArraySet(genTypeID.itables) + interfaces.foreach { iface => + val idx = ctx.getItableIdx(iface) + instrs += wa.GlobalGet(genGlobalID.forITable(className)) + instrs += wa.I32Const(idx) + + for (method <- iface.tableEntries) + instrs += ctx.refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName) + instrs += wa.StructNew(genTypeID.forITable(iface.name)) + instrs += wa.ArraySet(genTypeID.itables) + } } } @@ -178,10 +179,10 @@ final class Emitter(config: Emitter.Config) { // Emit the static initializers - for (className <- classesWithStaticInit) { + for (clazz <- sortedClasses if clazz.hasStaticInitializer) { val funcName = genFunctionID.forMethod( MemberNamespace.StaticConstructor, - className, + clazz.className, StaticInitializerName ) instrs += wa.Call(funcName) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 29f6c48..38e4c2a 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -44,7 +44,6 @@ final class WasmContext { private val functionTypes = LinkedHashMap.empty[watpe.FunctionType, wanme.TypeID] private val tableFunctionTypes = mutable.HashMap.empty[MethodName, wanme.TypeID] private val constantStringGlobals = LinkedHashMap.empty[String, StringData] - private val classItableGlobals = mutable.ListBuffer.empty[ClassName] private val closureDataTypes = LinkedHashMap.empty[List[Type], wanme.TypeID] val moduleBuilder: ModuleBuilder = { @@ -216,14 +215,6 @@ final class WasmContext { def addGlobal(g: wamod.Global): Unit = moduleBuilder.addGlobal(g) - def addGlobalITable(name: ClassName, g: wamod.Global): Unit = { - classItableGlobals += name - addGlobal(g) - } - - def getAllClassesWithITableGlobal(): List[ClassName] = - classItableGlobals.toList - def getImportedModuleGlobal(moduleName: String): wanme.GlobalID = { val name = genGlobalID.forImportedModule(moduleName) if (_importedModules.add(moduleName)) { @@ -281,6 +272,12 @@ object WasmContext { val staticFieldMirrors: Map[FieldName, List[String]], private var _itableIdx: Int ) { + + /** Does this Scala class implement any interface? */ + val classImplementsAnyInterface = + if (!kind.isClass && kind != ClassKind.HijackedClass) false + else ancestors.exists(a => a != name && ctx.getClassInfo(a).isInterface) + val resolvedMethodInfos: Map[MethodName, ConcreteMethodInfo] = { if (kind.isClass || kind == ClassKind.HijackedClass) { val inherited: Map[MethodName, ConcreteMethodInfo] = superClass match { From 50b2eb1bc802beec759afc23dc4da50fca51fbb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:01:37 +0200 Subject: [PATCH 06/11] Make inheritance info from `ClassInfo` private. The use sites that need that information already have a `LinkedClass` from which they can get it. --- .../linker/backend/wasmemitter/ClassEmitter.scala | 11 ++++++----- .../scalajs/linker/backend/wasmemitter/Emitter.scala | 2 +- .../linker/backend/wasmemitter/Preprocessor.scala | 9 ++++----- .../linker/backend/wasmemitter/WasmContext.scala | 5 ++--- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala index 76bc66a..bfaad56 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala @@ -184,7 +184,7 @@ class ClassEmitter(coreSpec: CoreSpec) { } val strictAncestorsValue: List[wa.Instr] = { - val ancestors = ctx.getClassInfo(className).ancestors + val ancestors = clazz.ancestors // By spec, the first element of `ancestors` is always the class itself assert( @@ -205,7 +205,7 @@ class ClassEmitter(coreSpec: CoreSpec) { val cloneFunction = { // If the class is concrete and implements the `java.lang.Cloneable`, // `genCloneFunction` should've generated the clone function - if (!classInfo.isAbstract && classInfo.ancestors.contains(CloneableClass)) + if (!classInfo.isAbstract && clazz.ancestors.contains(CloneableClass)) wa.RefFunc(genFunctionID.clone(className)) else wa.RefNull(watpe.HeapType.NoFunc) @@ -283,7 +283,7 @@ class ClassEmitter(coreSpec: CoreSpec) { val classInfo = ctx.getClassInfo(className) // generate vtable type, this should be done for both abstract and concrete classes - val vtableTypeName = genVTableType(classInfo) + val vtableTypeName = genVTableType(clazz, classInfo) val isAbstractClass = !clazz.hasDirectInstances @@ -367,6 +367,7 @@ class ClassEmitter(coreSpec: CoreSpec) { } private def genVTableType( + clazz: LinkedClass, classInfo: ClassInfo )(implicit ctx: WasmContext): wanme.TypeID = { val className = classInfo.name @@ -380,9 +381,9 @@ class ClassEmitter(coreSpec: CoreSpec) { isMutable = false ) } - val superType = classInfo.superClass match { + val superType = clazz.superClass match { case None => genTypeID.typeData - case Some(s) => genTypeID.forVTable(s) + case Some(s) => genTypeID.forVTable(s.name) } val structType = watpe.StructType(CoreWasmLib.typeDataStructFields ::: vtableFields) val subType = watpe.SubType( diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala index fc5abbe..e45d16d 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala @@ -134,7 +134,7 @@ final class Emitter(config: Emitter.Config) { val classInfo = ctx.getClassInfo(className) if (classInfo.classImplementsAnyInterface) { - val interfaces = classInfo.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface) + val interfaces = clazz.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface) val resolvedMethodInfos = classInfo.resolvedMethodInfos interfaces.foreach { iface => diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index d4e0d9a..3298b68 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -126,7 +126,6 @@ object Preprocessor { classConcretePublicMethodNames, allFieldDefs, clazz.superClass.map(_.name), - clazz.interfaces.map(_.name), clazz.ancestors, clazz.hasInstances, !clazz.hasDirectInstances, @@ -282,8 +281,8 @@ object Preprocessor { nextIdx += 1 new Bucket(idx) } - def getAllInterfaces(info: ClassInfo): List[ClassName] = - info.ancestors.filter(ctx.getClassInfo(_).isInterface) + def getAllInterfaces(clazz: LinkedClass): List[ClassName] = + clazz.ancestors.filter(ctx.getClassInfo(_).isInterface) val buckets = new mutable.ListBuffer[Bucket]() @@ -297,7 +296,7 @@ object Preprocessor { for (clazz <- classes.reverseIterator) { val info = ctx.getClassInfo(clazz.name.name) - val ifaces = getAllInterfaces(info) + val ifaces = getAllInterfaces(clazz) if (ifaces.nonEmpty) { val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet()) @@ -336,7 +335,7 @@ object Preprocessor { for (clazz <- classes) { val info = ctx.getClassInfo(clazz.name.name) - val ifaces = getAllInterfaces(info) + val ifaces = getAllInterfaces(clazz) if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) { val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet()) for { diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 38e4c2a..fec6cf9 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -261,9 +261,8 @@ object WasmContext { val jsClassCaptures: Option[List[ParamDef]], classConcretePublicMethodNames: List[MethodName], val allFieldDefs: List[FieldDef], - val superClass: Option[ClassName], - val interfaces: List[ClassName], - val ancestors: List[ClassName], + superClass: Option[ClassName], + ancestors: List[ClassName], private var _hasInstances: Boolean, val isAbstract: Boolean, val hasRuntimeTypeInfo: Boolean, From 39a1336a3d92e3749a3f1463f079007def694cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:08:36 +0200 Subject: [PATCH 07/11] Remove the `ctx` parameter to `ClassInfo`. Compute the two internal information that needed it from `Preprocessor` instead. --- .../backend/wasmemitter/Preprocessor.scala | 19 +++++++++++++------ .../backend/wasmemitter/WasmContext.scala | 16 ++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 3298b68..26173c4 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -67,6 +67,7 @@ object Preprocessor { private def preprocess(clazz: LinkedClass, staticFieldMirrors: Map[FieldName, List[String]])( implicit ctx: WasmContext ): Unit = { + val className = clazz.className val kind = clazz.kind val allFieldDefs: List[FieldDef] = @@ -79,7 +80,7 @@ object Preprocessor { case fd: FieldDef if !fd.flags.namespace.isStatic => fd case fd: JSFieldDef => - throw new AssertionError(s"Illegal $fd in Scala class ${clazz.className}") + throw new AssertionError(s"Illegal $fd in Scala class $className") } inheritedFields ::: myFieldDefs } else { @@ -99,6 +100,13 @@ object Preprocessor { } } + val superClass = clazz.superClass.map(sup => ctx.getClassInfo(sup.name)) + + // Does this Scala class implement any interface? + val classImplementsAnyInterface = + if (!kind.isClass && kind != ClassKind.HijackedClass) false + else clazz.ancestors.exists(a => a != className && ctx.getClassInfo(a).isInterface) + /* Should we emit a vtable/typeData global for this class? * * There are essentially three reasons for which we need them: @@ -117,16 +125,15 @@ object Preprocessor { val hasRuntimeTypeInfo = clazz.hasRuntimeTypeInfo || clazz.hasInstanceTests ctx.putClassInfo( - clazz.name.name, + className, new ClassInfo( - ctx, - clazz.name.name, + className, kind, clazz.jsClassCaptures, classConcretePublicMethodNames, allFieldDefs, - clazz.superClass.map(_.name), - clazz.ancestors, + superClass, + classImplementsAnyInterface, clazz.hasInstances, !clazz.hasDirectInstances, hasRuntimeTypeInfo, diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index fec6cf9..5a47f31 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -255,14 +255,13 @@ object WasmContext { final case class StringData(constantStringIndex: Int, offset: Int) final class ClassInfo( - ctx: WasmContext, val name: ClassName, val kind: ClassKind, val jsClassCaptures: Option[List[ParamDef]], classConcretePublicMethodNames: List[MethodName], val allFieldDefs: List[FieldDef], - superClass: Option[ClassName], - ancestors: List[ClassName], + superClass: Option[ClassInfo], + val classImplementsAnyInterface: Boolean, private var _hasInstances: Boolean, val isAbstract: Boolean, val hasRuntimeTypeInfo: Boolean, @@ -271,16 +270,10 @@ object WasmContext { val staticFieldMirrors: Map[FieldName, List[String]], private var _itableIdx: Int ) { - - /** Does this Scala class implement any interface? */ - val classImplementsAnyInterface = - if (!kind.isClass && kind != ClassKind.HijackedClass) false - else ancestors.exists(a => a != name && ctx.getClassInfo(a).isInterface) - val resolvedMethodInfos: Map[MethodName, ConcreteMethodInfo] = { if (kind.isClass || kind == ClassKind.HijackedClass) { val inherited: Map[MethodName, ConcreteMethodInfo] = superClass match { - case Some(superClass) => ctx.getClassInfo(superClass).resolvedMethodInfos + case Some(superClass) => superClass.resolvedMethodInfos case None => Map.empty } @@ -365,8 +358,7 @@ object WasmContext { kind match { case ClassKind.Class | ClassKind.ModuleClass | ClassKind.HijackedClass => - val superTableEntries = - superClass.fold[List[MethodName]](Nil)(sup => ctx.getClassInfo(sup).tableEntries) + val superTableEntries = superClass.fold[List[MethodName]](Nil)(_.tableEntries) val superTableEntrySet = superTableEntries.toSet /* When computing the table entries to add for this class, exclude: From 48902b25f545c7364b80827a7e5261337bc6d77a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:33:45 +0200 Subject: [PATCH 08/11] Remove state computed at startup from WasmContext. Instead, keep the stateful computations inside `Preprocessor`, and freeze the state when constructing `WasmContext`. This means that it is the responsibility of `Preprocessor` to create the instance of `WasmContext`, instead of `Emitter`. --- .../linker/backend/wasmemitter/Emitter.scala | 5 +- .../backend/wasmemitter/Preprocessor.scala | 74 ++++++++++++------- .../backend/wasmemitter/WasmContext.scala | 26 ++----- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala index e45d16d..a9084fb 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala @@ -32,8 +32,6 @@ final class Emitter(config: Emitter.Config) { def injectedIRFiles: Seq[IRFile] = Nil def emit(module: ModuleSet.Module, logger: Logger): Result = { - implicit val ctx: WasmContext = new WasmContext() - /* Sort by ancestor count so that superclasses always appear before * subclasses, then tie-break by name for stability. */ @@ -43,7 +41,8 @@ final class Emitter(config: Emitter.Config) { else a.className.compareTo(b.className) < 0 } - Preprocessor.preprocess(sortedClasses, module.topLevelExports) + implicit val ctx: WasmContext = + Preprocessor.preprocess(sortedClasses, module.topLevelExports) CoreWasmLib.genPreClasses() sortedClasses.foreach { clazz => diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala index 26173c4..9a209df 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Preprocessor.scala @@ -13,15 +13,19 @@ import EmbeddedConstants._ import WasmContext._ object Preprocessor { - def preprocess(classes: List[LinkedClass], tles: List[LinkedTopLevelExport])(implicit - ctx: WasmContext - ): Unit = { + def preprocess(classes: List[LinkedClass], tles: List[LinkedTopLevelExport]): WasmContext = { val staticFieldMirrors = computeStaticFieldMirrors(tles) + val classInfosBuilder = mutable.HashMap.empty[ClassName, ClassInfo] val definedReflectiveProxyNames = mutable.HashSet.empty[MethodName] for (clazz <- classes) { - preprocess(clazz, staticFieldMirrors.getOrElse(clazz.className, Map.empty)) + val classInfo = preprocess( + clazz, + staticFieldMirrors.getOrElse(clazz.className, Map.empty), + classInfosBuilder + ) + classInfosBuilder += clazz.className -> classInfo // For Scala classes, collect the reflective proxy method names that it defines if (clazz.kind.isClass || clazz.kind == ClassKind.HijackedClass) { @@ -30,19 +34,23 @@ object Preprocessor { } } + val classInfos = classInfosBuilder.toMap + // sort for stability - ctx.setReflectiveProxyIDs(definedReflectiveProxyNames.toList.sorted.zipWithIndex.toMap) + val reflectiveProxyIDs = definedReflectiveProxyNames.toList.sorted.zipWithIndex.toMap - val collector = new AbstractMethodCallCollector(ctx) + val collector = new AbstractMethodCallCollector(classInfos) for (clazz <- classes) collector.collectAbstractMethodCalls(clazz) for (tle <- tles) collector.collectAbstractMethodCalls(tle) for (clazz <- classes) { - ctx.getClassInfo(clazz.className).buildMethodTable() + classInfos(clazz.className).buildMethodTable() } - assignBuckets(classes) + val itablesLength = assignBuckets(classes, classInfos) + + new WasmContext(classInfos, reflectiveProxyIDs, itablesLength) } private def computeStaticFieldMirrors( @@ -64,9 +72,11 @@ object Preprocessor { result } - private def preprocess(clazz: LinkedClass, staticFieldMirrors: Map[FieldName, List[String]])( - implicit ctx: WasmContext - ): Unit = { + private def preprocess( + clazz: LinkedClass, + staticFieldMirrors: Map[FieldName, List[String]], + previousClassInfos: collection.Map[ClassName, ClassInfo] + ): ClassInfo = { val className = clazz.className val kind = clazz.kind @@ -74,7 +84,7 @@ object Preprocessor { if (kind.isClass) { val inheritedFields = clazz.superClass match { case None => Nil - case Some(sup) => ctx.getClassInfo(sup.name).allFieldDefs + case Some(sup) => previousClassInfos(sup.name).allFieldDefs } val myFieldDefs = clazz.fields.collect { case fd: FieldDef if !fd.flags.namespace.isStatic => @@ -100,12 +110,15 @@ object Preprocessor { } } - val superClass = clazz.superClass.map(sup => ctx.getClassInfo(sup.name)) + val superClass = clazz.superClass.map(sup => previousClassInfos(sup.name)) + + val strictClassAncestors = + if (kind.isClass || kind == ClassKind.HijackedClass) clazz.ancestors.tail + else Nil // Does this Scala class implement any interface? val classImplementsAnyInterface = - if (!kind.isClass && kind != ClassKind.HijackedClass) false - else clazz.ancestors.exists(a => a != className && ctx.getClassInfo(a).isInterface) + strictClassAncestors.exists(a => previousClassInfos(a).isInterface) /* Should we emit a vtable/typeData global for this class? * @@ -124,8 +137,7 @@ object Preprocessor { */ val hasRuntimeTypeInfo = clazz.hasRuntimeTypeInfo || clazz.hasInstanceTests - ctx.putClassInfo( - className, + val classInfo = { new ClassInfo( className, kind, @@ -142,12 +154,12 @@ object Preprocessor { staticFieldMirrors, _itableIdx = -1 ) - ) + } // Update specialInstanceTypes for ancestors of hijacked classes if (clazz.kind == ClassKind.HijackedClass) { def addSpecialInstanceTypeOnAllAncestors(jsValueType: Int): Unit = - clazz.ancestors.foreach(ctx.getClassInfo(_).addSpecialInstanceType(jsValueType)) + strictClassAncestors.foreach(previousClassInfos(_).addSpecialInstanceType(jsValueType)) clazz.className match { case BoxedBooleanClass => @@ -168,7 +180,9 @@ object Preprocessor { * Manually mark all ancestors of instantiated classes as having instances. */ if (clazz.hasDirectInstances && !kind.isJSType) - clazz.ancestors.foreach(ancestor => ctx.getClassInfo(ancestor).setHasInstances()) + strictClassAncestors.foreach(ancestor => previousClassInfos(ancestor).setHasInstances()) + + classInfo } /** Collect FunctionInfo based on the abstract method call @@ -191,7 +205,8 @@ object Preprocessor { * It keeps B.c because it's concrete and used. But because `C.c` isn't there at all anymore, if * we have val `x: C` and we call `x.c`, we don't find the method at all. */ - private class AbstractMethodCallCollector(ctx: WasmContext) extends Traversers.Traverser { + private class AbstractMethodCallCollector(classInfos: Map[ClassName, ClassInfo]) + extends Traversers.Traverser { def collectAbstractMethodCalls(clazz: LinkedClass): Unit = { for (method <- clazz.methods) traverseMethodDef(method) @@ -217,11 +232,11 @@ object Preprocessor { case Apply(flags, receiver, methodName, _) if !methodName.name.isReflectiveProxy => receiver.tpe match { case ClassType(className) => - val classInfo = ctx.getClassInfo(className) + val classInfo = classInfos(className) if (classInfo.hasInstances) classInfo.registerDynamicCall(methodName.name) case AnyType => - ctx.getClassInfo(ObjectClass).registerDynamicCall(methodName.name) + classInfos(ObjectClass).registerDynamicCall(methodName.name) case _ => // For all other cases, including arrays, we will always perform a static dispatch () @@ -279,7 +294,10 @@ object Preprocessor { * Inclusion Tests" * [[https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests]] */ - private def assignBuckets(allClasses: List[LinkedClass])(implicit ctx: WasmContext): Unit = { + private def assignBuckets( + allClasses: List[LinkedClass], + classInfos: Map[ClassName, ClassInfo] + ): Int = { val classes = allClasses.filterNot(_.kind.isJSType) var nextIdx = 0 @@ -289,7 +307,7 @@ object Preprocessor { new Bucket(idx) } def getAllInterfaces(clazz: LinkedClass): List[ClassName] = - clazz.ancestors.filter(ctx.getClassInfo(_).isInterface) + clazz.ancestors.filter(classInfos(_).isInterface) val buckets = new mutable.ListBuffer[Bucket]() @@ -302,7 +320,7 @@ object Preprocessor { val spines = new mutable.HashSet[ClassName]() for (clazz <- classes.reverseIterator) { - val info = ctx.getClassInfo(clazz.name.name) + val info = classInfos(clazz.name.name) val ifaces = getAllInterfaces(clazz) if (ifaces.nonEmpty) { val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet()) @@ -341,7 +359,7 @@ object Preprocessor { } for (clazz <- classes) { - val info = ctx.getClassInfo(clazz.name.name) + val info = classInfos(clazz.name.name) val ifaces = getAllInterfaces(clazz) if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) { val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet()) @@ -369,7 +387,7 @@ object Preprocessor { } } - ctx.setItablesLength(buckets.length) + buckets.length } private final class Bucket(idx: Int) { diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 5a47f31..050ce05 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -25,22 +25,13 @@ import org.scalajs.linker.backend.webassembly.{Types => watpe} import VarGen._ import org.scalajs.ir.OriginalName -final class WasmContext { +final class WasmContext( + classInfo: Map[ClassName, WasmContext.ClassInfo], + reflectiveProxies: Map[MethodName, Int], + val itablesLength: Int +) { import WasmContext._ - private val classInfo = mutable.Map[ClassName, ClassInfo]() - private var reflectiveProxies: Map[MethodName, Int] = null - private var _itablesLength: Int = -1 - - /** Sets the length of the itables arrays, only for use by `Preprocessor`. */ - def setItablesLength(length: Int): Unit = - _itablesLength = length - - def itablesLength: Int = { - require(_itablesLength != -1, s"itablesLength was not yet assigned") - _itablesLength - } - private val functionTypes = LinkedHashMap.empty[watpe.FunctionType, wanme.TypeID] private val tableFunctionTypes = mutable.HashMap.empty[MethodName, wanme.TypeID] private val constantStringGlobals = LinkedHashMap.empty[String, StringData] @@ -102,10 +93,6 @@ final class WasmContext { ArrayType(typeRef) } - /** Sets the map of reflexity proxy IDs, only for use by `Preprocessor`. */ - def setReflectiveProxyIDs(proxyIDs: Map[MethodName, Int]): Unit = - reflectiveProxies = proxyIDs - /** Retrieves a unique identifier for a reflective proxy with the given name. * * If no class defines a reflective proxy with the given name, returns `-1`. @@ -235,9 +222,6 @@ final class WasmContext { def addFuncDeclaration(name: wanme.FunctionID): Unit = _funcDeclarations += name - def putClassInfo(name: ClassName, info: ClassInfo): Unit = - classInfo.put(name, info) - def addJSPrivateFieldName(fieldName: FieldName): Unit = _jsPrivateFieldNames += fieldName From 772cac03a47b2261978ae132d256aa5b88d876c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:52:11 +0200 Subject: [PATCH 09/11] Use the imported module list from the `ModuleSet`. We already have the set of all imported modules. There is no need to compute it while generating code. --- .../linker/backend/wasmemitter/Emitter.scala | 18 ++++++++++++++++- .../linker/backend/wasmemitter/SWasmGen.scala | 2 +- .../backend/wasmemitter/WasmContext.scala | 20 ------------------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala index a9084fb..6a07178 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala @@ -44,6 +44,22 @@ final class Emitter(config: Emitter.Config) { implicit val ctx: WasmContext = Preprocessor.preprocess(sortedClasses, module.topLevelExports) + // Sort for stability + val allImportedModules: List[String] = module.externalDependencies.toList.sorted + + // Gen imports of external modules on the Wasm side + for (moduleName <- allImportedModules) { + val id = genGlobalID.forImportedModule(moduleName) + val origName = OriginalName("import." + moduleName) + ctx.moduleBuilder.addImport( + wamod.Import( + "__scalaJSImports", + moduleName, + wamod.ImportDesc.Global(id, origName, watpe.RefType.anyref, isMutable = false) + ) + ) + } + CoreWasmLib.genPreClasses() sortedClasses.foreach { clazz => classEmitter.genClassDef(clazz) @@ -63,7 +79,7 @@ final class Emitter(config: Emitter.Config) { val loaderContent = LoaderContent.bytesContent val jsFileContent = - buildJSFileContent(module, module.id.id + ".wasm", ctx.allImportedModules) + buildJSFileContent(module, module.id.id + ".wasm", allImportedModules) new Result(wasmModule, loaderContent, jsFileContent) } diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala index 620c236..af7a319 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/SWasmGen.scala @@ -80,7 +80,7 @@ object SWasmGen { fb += Call(genFunctionID.jsGlobalRefGet) genFollowPath(path) case JSNativeLoadSpec.Import(module, path) => - fb += GlobalGet(ctx.getImportedModuleGlobal(module)) + fb += GlobalGet(genGlobalID.forImportedModule(module)) genFollowPath(path) case JSNativeLoadSpec.ImportWithGlobalFallback(importSpec, globalSpec) => genLoadJSFromSpec(fb, importSpec) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 050ce05..1736a3e 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -55,9 +55,6 @@ final class WasmContext( private var nextConstantStringIndex: Int = 0 private var nextClosureDataTypeIndex: Int = 1 - private val _importedModules: mutable.LinkedHashSet[String] = - new mutable.LinkedHashSet() - private val _jsPrivateFieldNames: mutable.ListBuffer[FieldName] = new mutable.ListBuffer() private val _funcDeclarations: mutable.LinkedHashSet[wanme.FunctionID] = @@ -202,23 +199,6 @@ final class WasmContext( def addGlobal(g: wamod.Global): Unit = moduleBuilder.addGlobal(g) - def getImportedModuleGlobal(moduleName: String): wanme.GlobalID = { - val name = genGlobalID.forImportedModule(moduleName) - if (_importedModules.add(moduleName)) { - val origName = OriginalName("import." + moduleName) - moduleBuilder.addImport( - wamod.Import( - "__scalaJSImports", - moduleName, - wamod.ImportDesc.Global(name, origName, watpe.RefType.anyref, isMutable = false) - ) - ) - } - name - } - - def allImportedModules: List[String] = _importedModules.toList - def addFuncDeclaration(name: wanme.FunctionID): Unit = _funcDeclarations += name From 176e9820e61adaafed588f819a030600bd33d6ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:56:34 +0200 Subject: [PATCH 10/11] Do not store the set of JS private fields in `WasmContext`. We can recompute it in `genStartFunction` instead. --- .../linker/backend/wasmemitter/ClassEmitter.scala | 1 - .../linker/backend/wasmemitter/Emitter.scala | 13 ++++++++++--- .../linker/backend/wasmemitter/WasmContext.scala | 8 -------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala index bfaad56..648d677 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/ClassEmitter.scala @@ -678,7 +678,6 @@ class ClassEmitter(coreSpec: CoreSpec) { isMutable = true ) ) - ctx.addJSPrivateFieldName(name.name) case _ => () } diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala index 6a07178..c4be761 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/Emitter.scala @@ -187,9 +187,16 @@ final class Emitter(config: Emitter.Config) { // Initialize the JS private field symbols - for (fieldName <- ctx.getAllJSPrivateFieldNames()) { - instrs += wa.Call(genFunctionID.newSymbol) - instrs += wa.GlobalSet(genGlobalID.forJSPrivateField(fieldName)) + for (clazz <- sortedClasses if clazz.kind.isJSClass) { + for (fieldDef <- clazz.fields) { + fieldDef match { + case FieldDef(flags, name, _, _) if !flags.namespace.isStatic => + instrs += wa.Call(genFunctionID.newSymbol) + instrs += wa.GlobalSet(genGlobalID.forJSPrivateField(name.name)) + case _ => + () + } + } } // Emit the static initializers diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 1736a3e..5b2c813 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -55,8 +55,6 @@ final class WasmContext( private var nextConstantStringIndex: Int = 0 private var nextClosureDataTypeIndex: Int = 1 - private val _jsPrivateFieldNames: mutable.ListBuffer[FieldName] = - new mutable.ListBuffer() private val _funcDeclarations: mutable.LinkedHashSet[wanme.FunctionID] = new mutable.LinkedHashSet() @@ -202,15 +200,9 @@ final class WasmContext( def addFuncDeclaration(name: wanme.FunctionID): Unit = _funcDeclarations += name - def addJSPrivateFieldName(fieldName: FieldName): Unit = - _jsPrivateFieldNames += fieldName - def getFinalStringPool(): (Array[Byte], Int) = (stringPool.toArray, nextConstantStringIndex) - def getAllJSPrivateFieldNames(): List[FieldName] = - _jsPrivateFieldNames.toList - def getAllFuncDeclarations(): List[wanme.FunctionID] = _funcDeclarations.toList } From 892fdc55ae71b6d1aed269b927fc1910d3c277ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Tue, 21 May 2024 11:59:24 +0200 Subject: [PATCH 11/11] Remove some dead code methods in `WasmContext`. --- .../linker/backend/wasmemitter/WasmContext.scala | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala index 5b2c813..d67e8eb 100644 --- a/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala +++ b/wasm/src/main/scala/org/scalajs/linker/backend/wasmemitter/WasmContext.scala @@ -184,22 +184,13 @@ final class WasmContext( } def refFuncWithDeclaration(name: wanme.FunctionID): wa.RefFunc = { - addFuncDeclaration(name) + _funcDeclarations += name wa.RefFunc(name) } - def addExport(exprt: wamod.Export): Unit = - moduleBuilder.addExport(exprt) - - def addFunction(fun: wamod.Function): Unit = - moduleBuilder.addFunction(fun) - def addGlobal(g: wamod.Global): Unit = moduleBuilder.addGlobal(g) - def addFuncDeclaration(name: wanme.FunctionID): Unit = - _funcDeclarations += name - def getFinalStringPool(): (Array[Byte], Int) = (stringPool.toArray, nextConstantStringIndex)