Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Remove most of the state from WasmContext. #138

Merged
merged 11 commits into from
May 22, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class ClassEmitter(coreSpec: CoreSpec) {
}

val strictAncestorsValue: List[wa.Instr] = {
val ancestors = ctx.getClassInfo(className).ancestors
val ancestors = clazz.ancestors

// By spec, the first element of `ancestors` is always the class itself
assert(
Expand All @@ -205,7 +205,7 @@ class ClassEmitter(coreSpec: CoreSpec) {
val cloneFunction = {
// If the class is concrete and implements the `java.lang.Cloneable`,
// `genCloneFunction` should've generated the clone function
if (!classInfo.isAbstract && classInfo.ancestors.contains(CloneableClass))
if (!classInfo.isAbstract && clazz.ancestors.contains(CloneableClass))
wa.RefFunc(genFunctionID.clone(className))
else
wa.RefNull(watpe.HeapType.NoFunc)
Expand Down Expand Up @@ -283,7 +283,7 @@ class ClassEmitter(coreSpec: CoreSpec) {
val classInfo = ctx.getClassInfo(className)

// generate vtable type, this should be done for both abstract and concrete classes
val vtableTypeName = genVTableType(classInfo)
val vtableTypeName = genVTableType(clazz, classInfo)

val isAbstractClass = !clazz.hasDirectInstances

Expand Down Expand Up @@ -367,6 +367,7 @@ class ClassEmitter(coreSpec: CoreSpec) {
}

private def genVTableType(
clazz: LinkedClass,
classInfo: ClassInfo
)(implicit ctx: WasmContext): wanme.TypeID = {
val className = classInfo.name
Expand All @@ -380,9 +381,9 @@ class ClassEmitter(coreSpec: CoreSpec) {
isMutable = false
)
}
val superType = classInfo.superClass match {
val superType = clazz.superClass match {
case None => genTypeID.typeData
case Some(s) => genTypeID.forVTable(s)
case Some(s) => genTypeID.forVTable(s.name)
}
val structType = watpe.StructType(CoreWasmLib.typeDataStructFields ::: vtableFields)
val subType = watpe.SubType(
Expand Down Expand Up @@ -510,8 +511,7 @@ class ClassEmitter(coreSpec: CoreSpec) {

instrs += wa.GlobalGet(genGlobalID.forVTable(className))

val interfaces = classInfo.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface)
if (!interfaces.isEmpty)
if (classInfo.classImplementsAnyInterface)
instrs += wa.GlobalGet(genGlobalID.forITable(className))
else
instrs += wa.RefNull(watpe.HeapType(genTypeID.itables))
Expand Down Expand Up @@ -618,9 +618,8 @@ class ClassEmitter(coreSpec: CoreSpec) {
*/
private def genGlobalClassItable(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
val className = clazz.className
val info = ctx.getClassInfo(className)
val implementsAnyInterface = info.ancestors.exists(a => ctx.getClassInfo(a).isInterface)
if (implementsAnyInterface) {

if (ctx.getClassInfo(className).classImplementsAnyInterface) {
val globalName = genGlobalID.forITable(className)
val itablesInit = List(
wa.I32Const(ctx.itablesLength),
Expand All @@ -633,7 +632,7 @@ class ClassEmitter(coreSpec: CoreSpec) {
wa.Expr(itablesInit),
isMutable = false
)
ctx.addGlobalITable(className, global)
ctx.addGlobal(global)
}
}

Expand Down Expand Up @@ -679,7 +678,6 @@ class ClassEmitter(coreSpec: CoreSpec) {
isMutable = true
)
)
ctx.addJSPrivateFieldName(name.name)
case _ =>
()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ final class Emitter(config: Emitter.Config) {
def injectedIRFiles: Seq[IRFile] = Nil

def emit(module: ModuleSet.Module, logger: Logger): Result = {
implicit val ctx: WasmContext = new WasmContext()

/* Sort by ancestor count so that superclasses always appear before
* subclasses, then tie-break by name for stability.
*/
Expand All @@ -43,7 +41,24 @@ final class Emitter(config: Emitter.Config) {
else a.className.compareTo(b.className) < 0
}

Preprocessor.preprocess(sortedClasses, module.topLevelExports)
implicit val ctx: WasmContext =
Preprocessor.preprocess(sortedClasses, module.topLevelExports)

// Sort for stability
val allImportedModules: List[String] = module.externalDependencies.toList.sorted

// Gen imports of external modules on the Wasm side
for (moduleName <- allImportedModules) {
val id = genGlobalID.forImportedModule(moduleName)
val origName = OriginalName("import." + moduleName)
ctx.moduleBuilder.addImport(
wamod.Import(
"__scalaJSImports",
moduleName,
wamod.ImportDesc.Global(id, origName, watpe.RefType.anyref, isMutable = false)
)
)
}

CoreWasmLib.genPreClasses()
sortedClasses.foreach { clazz =>
Expand All @@ -54,27 +69,24 @@ final class Emitter(config: Emitter.Config) {
}
CoreWasmLib.genPostClasses()

val classesWithStaticInit =
sortedClasses.filter(_.hasStaticInitializer).map(_.className)

complete(
sortedClasses,
module.initializers.toList,
classesWithStaticInit,
module.topLevelExports
)

val wasmModule = ctx.moduleBuilder.build()

val loaderContent = LoaderContent.bytesContent
val jsFileContent =
buildJSFileContent(module, module.id.id + ".wasm", ctx.allImportedModules)
buildJSFileContent(module, module.id.id + ".wasm", allImportedModules)

new Result(wasmModule, loaderContent, jsFileContent)
}

private def complete(
sortedClasses: List[LinkedClass],
moduleInitializers: List[ModuleInitializer.Initializer],
classesWithStaticInit: List[ClassName],
topLevelExportDefs: List[LinkedTopLevelExport]
)(implicit ctx: WasmContext): Unit = {
/* Before generating the string globals in `genStartFunction()`, make sure
Expand Down Expand Up @@ -114,13 +126,13 @@ final class Emitter(config: Emitter.Config) {
)
)

genStartFunction(moduleInitializers, classesWithStaticInit, topLevelExportDefs)
genStartFunction(sortedClasses, moduleInitializers, topLevelExportDefs)
genDeclarativeElements()
}

private def genStartFunction(
sortedClasses: List[LinkedClass],
moduleInitializers: List[ModuleInitializer.Initializer],
classesWithStaticInit: List[ClassName],
topLevelExportDefs: List[LinkedTopLevelExport]
)(implicit ctx: WasmContext): Unit = {
import org.scalajs.ir.Trees._
Expand All @@ -132,20 +144,24 @@ final class Emitter(config: Emitter.Config) {
val instrs: fb.type = fb

// Initialize itables
for (className <- ctx.getAllClassesWithITableGlobal()) {
for (clazz <- sortedClasses if clazz.kind.isClass && clazz.hasDirectInstances) {
val className = clazz.className
val classInfo = ctx.getClassInfo(className)
val interfaces = classInfo.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface)
val resolvedMethodInfos = classInfo.resolvedMethodInfos

interfaces.foreach { iface =>
val idx = ctx.getItableIdx(iface)
instrs += wa.GlobalGet(genGlobalID.forITable(className))
instrs += wa.I32Const(idx)
if (classInfo.classImplementsAnyInterface) {
val interfaces = clazz.ancestors.map(ctx.getClassInfo(_)).filter(_.isInterface)
val resolvedMethodInfos = classInfo.resolvedMethodInfos

for (method <- iface.tableEntries)
instrs += ctx.refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
instrs += wa.StructNew(genTypeID.forITable(iface.name))
instrs += wa.ArraySet(genTypeID.itables)
interfaces.foreach { iface =>
val idx = ctx.getItableIdx(iface)
instrs += wa.GlobalGet(genGlobalID.forITable(className))
instrs += wa.I32Const(idx)

for (method <- iface.tableEntries)
instrs += ctx.refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
instrs += wa.StructNew(genTypeID.forITable(iface.name))
instrs += wa.ArraySet(genTypeID.itables)
}
}
}

Expand All @@ -171,17 +187,24 @@ final class Emitter(config: Emitter.Config) {

// Initialize the JS private field symbols

for (fieldName <- ctx.getAllJSPrivateFieldNames()) {
instrs += wa.Call(genFunctionID.newSymbol)
instrs += wa.GlobalSet(genGlobalID.forJSPrivateField(fieldName))
for (clazz <- sortedClasses if clazz.kind.isJSClass) {
for (fieldDef <- clazz.fields) {
fieldDef match {
case FieldDef(flags, name, _, _) if !flags.namespace.isStatic =>
instrs += wa.Call(genFunctionID.newSymbol)
instrs += wa.GlobalSet(genGlobalID.forJSPrivateField(name.name))
case _ =>
()
}
}
}

// Emit the static initializers

for (className <- classesWithStaticInit) {
for (clazz <- sortedClasses if clazz.hasStaticInitializer) {
val funcName = genFunctionID.forMethod(
MemberNamespace.StaticConstructor,
className,
clazz.className,
StaticInitializerName
)
instrs += wa.Call(funcName)
Expand Down
Loading