diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index e74c9a9da..c621d3279 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -33,7 +33,7 @@ jobs:
       fail-fast: false
       matrix:
         java-version: [8, 11]
-        scala-version: [2.12.18, 2.13.12, 3.2.2, 3.3.1]
+        scala-version: [2.12.18, 2.13.12, 3.2.2, 3.3.3]
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v2
@@ -47,7 +47,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        scala-version: [2.12.18, 2.13.12, 3.2.2]
+        scala-version: [2.12.18, 2.13.12, 3.3.3]
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v2
diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/AmmonitePhase.scala b/amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/AmmonitePhase.scala
similarity index 100%
rename from amm/compiler/src/main/scala-3/ammonite/compiler/AmmonitePhase.scala
rename to amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/AmmonitePhase.scala
diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/Compiler.scala b/amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/Compiler.scala
similarity index 100%
rename from amm/compiler/src/main/scala-3/ammonite/compiler/Compiler.scala
rename to amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/Compiler.scala
diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/Preprocessor.scala b/amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/Preprocessor.scala
similarity index 100%
rename from amm/compiler/src/main/scala-3/ammonite/compiler/Preprocessor.scala
rename to amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/Preprocessor.scala
diff --git a/amm/compiler/src/main/scala-3/ammonite/compiler/SyntaxHighlighting.scala b/amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/SyntaxHighlighting.scala
similarity index 100%
rename from amm/compiler/src/main/scala-3/ammonite/compiler/SyntaxHighlighting.scala
rename to amm/compiler/src/main/scala-3.0.0-3.3.1/ammonite/compiler/SyntaxHighlighting.scala
diff --git a/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/AmmonitePhase.scala b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/AmmonitePhase.scala
new file mode 100644
index 000000000..333645aa7
--- /dev/null
+++ b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/AmmonitePhase.scala
@@ -0,0 +1,264 @@
+package ammonite.compiler
+
+import ammonite.util.{ImportData, Imports, Name => AmmName, Printer, Util}
+
+import dotty.tools.dotc
+import dotty.tools.dotc.core.StdNames.nme
+import dotc.ast.Trees._
+import dotc.ast.{tpd, untpd}
+import dotc.core.Flags
+import dotc.core.Contexts._
+import dotc.core.Names.Name
+import dotc.core.Phases.Phase
+import dotc.core.Symbols.{NoSymbol, Symbol, newSymbol}
+import dotc.core.Types.{TermRef, Type, TypeTraverser}
+
+import scala.collection.mutable
+
+class AmmonitePhase(
+  userCodeNestingLevel: => Int,
+  needsUsedEarlierDefinitions: => Boolean
+) extends Phase:
+  import tpd._
+
+  def phaseName: String = "ammonite"
+
+  private var myImports = new mutable.ListBuffer[(Boolean, String, String, Seq[AmmName])]
+  private var usedEarlierDefinitions0 = new mutable.ListBuffer[String]
+
+  def importData: Seq[ImportData] =
+    val grouped = myImports
+      .toList
+      .distinct
+      .groupBy { case (a, b, c, d) => (b, c, d) }
+      .mapValues(_.map(_._1))
+
+    val open = for {
+      ((fromName, toName, importString), items) <- grouped
+      if !CompilerUtil.ignoredNames(fromName)
+    } yield {
+      val importType = items match{
+        case Seq(true) => ImportData.Type
+        case Seq(false) => ImportData.Term
+        case Seq(_, _) => ImportData.TermType
+      }
+
+      ImportData(AmmName(fromName), AmmName(toName), importString, importType)
+    }
+
+    open.toVector.sortBy(x => Util.encodeScalaSourcePath(x.prefix))
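+
+  // Distinct names of wrapper-level definitions from earlier commands that the
+  // user code references; collected by updateUsedEarlierDefinitions below.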
+  def usedEarlierDefinitions: Seq[String] =
+    usedEarlierDefinitions0.toList.distinct
+
+  private def saneSym(name: Name, sym: Symbol)(using Context): Boolean =
+    !name.decode.toString.contains('$') &&
+    sym.exists &&
+    // !sym.is(Flags.Synthetic) &&
+    !scala.util.Try(sym.is(Flags.Private)).toOption.getOrElse(true) &&
+    !scala.util.Try(sym.is(Flags.Protected)).toOption.getOrElse(true) &&
+    // sym.is(Flags.Public) &&
+    !CompilerUtil.ignoredSyms(sym.toString) &&
+    !CompilerUtil.ignoredNames(name.decode.toString)
+
+  private def saneSym(sym: Symbol)(using Context): Boolean =
+    saneSym(sym.name, sym)
+
+  private def processTree(t: tpd.Tree)(using Context): Unit = {
+    val sym = t.symbol
+    val name = t match {
+      case t: tpd.ValDef => t.name
+      case _ => sym.name
+    }
+    if (saneSym(name, sym)) {
+      val name = sym.name.decode.toString
+      myImports.addOne((sym.isType, name, name, Nil))
+    }
+  }
+
+  private def processImport(i: tpd.Import)(using Context): Unit = {
+    val expr = i.expr
+    val selectors = i.selectors
+
+    // Most of that logic was adapted from AmmonitePlugin, the Scala 2 counterpart
+    // of this file.
+
+    val prefix =
+      val (_ :: nameListTail, symbolHead :: _) = {
+        def rec(expr: tpd.Tree): List[(Name, Symbol)] = {
+          expr match {
+            case s @ tpd.Select(lhs, _) => (s.symbol.name -> s.symbol) :: rec(lhs)
+            case i @ tpd.Ident(name) => List(name -> i.symbol)
+            case t @ tpd.This(pkg) => List(pkg.name -> t.symbol)
+          }
+        }
+        rec(expr).reverse.unzip
+      }
+
+      val headFullPath = symbolHead.fullName.decode.toString.split('.')
+        .map(n => if (n.endsWith("$")) n.stripSuffix("$") else n) // meh
+      // prefix package imports with `_root_` to try and stop random
+      // variables from interfering with them. If someone defines a value
+      // called `_root_`, this will still break, but that's their problem
+      val rootPrefix = if(symbolHead.denot.is(Flags.Package)) Seq("_root_") else Nil
+      val tailPath = nameListTail.map(_.decode.toString)
+
+      (rootPrefix ++ headFullPath ++ tailPath).map(AmmName(_))
+
+    def isMask(sel: untpd.ImportSelector) = sel.name != nme.WILDCARD && sel.rename == nme.WILDCARD
+
+    val renameMap =
+
+      /**
+       * A map of each name importable from `expr`, to a `Seq[Boolean]`
+       * containing a `true` if there's a type-symbol you can import, `false`
+       * if there's a non-type symbol and both if there are both type and
+       * non-type symbols that are importable for that name
+       */
+      val importableIsTypes =
+        expr.tpe
+          .allMembers
+          .map(_.symbol)
+          .filter(saneSym(_))
+          .groupBy(_.name.decode.toString)
+          .mapValues(_.map(_.isType).toVector)
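+
+      // Map each (isType, renamed-to) pair back to the original member name,
+      // so that e.g. `import a.{b => c}` records `c` as an alias of `b`.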
+      val renamings = for{
+        t @ untpd.ImportSelector(name, renameTree, _) <- selectors
+        if !isMask(t)
+        // getOrElse just in case...
+        isType <- importableIsTypes.getOrElse(name.name.decode.toString, Nil)
+        Ident(rename) <- Option(renameTree)
+      } yield ((isType, rename.decode.toString), name.name.decode.toString)
+
+      renamings.toMap
+
+
+    def isUnimportableUnlessRenamed(sym: Symbol): Boolean =
+      sym eq NoSymbol
+
+    @scala.annotation.tailrec
+    def transformImport(selectors: List[untpd.ImportSelector], sym: Symbol): List[Symbol] =
+      selectors match {
+        case Nil => Nil
+        case sel :: Nil if sel.isWildcard =>
+          if (isUnimportableUnlessRenamed(sym)) Nil
+          else List(sym)
+        case (sel @ untpd.ImportSelector(from, to, _)) :: _
+            if from.name == (if (from.isTerm) sym.name.toTermName else sym.name.toTypeName) =>
+          if (isMask(sel)) Nil
+          else List(
+            newSymbol(sym.owner, sel.rename, sym.flags, sym.info, sym.privateWithin, sym.coord)
+          )
+        case _ :: rest => transformImport(rest, sym)
+      }
+
+    val symNames =
+      for {
+        sym <- expr.tpe.allMembers.map(_.symbol).flatMap(transformImport(selectors, _))
+        if saneSym(sym)
+      } yield (sym.isType, sym.name.decode.toString)
+
+    val syms = for {
+      // For some reason `info.allImportedSymbols` does not show imported
+      // type aliases when they are imported directly e.g.
+      //
+      // import scala.reflect.macros.Context
+      //
+      // As opposed to via import scala.reflect.macros._.
+      // Thus we need to combine allImportedSymbols with the renameMap
+      (isType, sym) <- (symNames.toList ++ renameMap.keys).distinct
+    } yield (isType, renameMap.getOrElse((isType, sym), sym), sym, prefix)
+
+    myImports ++= syms
+  }
+
+  private def updateUsedEarlierDefinitions(
+    wrapperSym: Symbol,
+    stats: List[tpd.Tree]
+  )(using Context): Unit = {
+    /*
+     * We list the variables from the first wrapper
+     * used from the user code.
+     *
+     * E.g. if, after wrapping, the code looks like
+     * ```
+     *   class cmd2 {
+     *
+     *     val cmd0 = ???
+     *     val cmd1 = ???
+     *
+     *     import cmd0.{
+     *       n
+     *     }
+     *
+     *     class Helper {
+     *       // user-typed code
+     *       val n0 = n + 1
+     *     }
+     *   }
+     * ```
+     * this would process the tree of `val n0 = n + 1`, find `n` as a tree like
+     * `cmd2.this.cmd0.n`, and put `cmd0` in `uses`.
+     */
+
+    val typeTraverser: TypeTraverser = new TypeTraverser {
+      def traverse(tpe: Type) = tpe match {
+        case tr: TermRef if tr.prefix.typeSymbol == wrapperSym =>
+          tr.designator match {
+            case n: Name => usedEarlierDefinitions0 += n.decode.toString
+            case s: Symbol => usedEarlierDefinitions0 += s.name.decode.toString
+            case _ => // can this happen?
+          }
+        case _ =>
+          traverseChildren(tpe)
+      }
+    }
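+
+    // The tree traverser records direct selections such as `cmd2.this.cmd0.n`,
+    // and hands the types of TypeTrees over to the type traverser above.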
+    val traverser: TreeTraverser = new TreeTraverser {
+      def traverse(tree: Tree)(using Context) = tree match {
+        case tpd.Select(node, name) if node.symbol == wrapperSym =>
+          usedEarlierDefinitions0 += name.decode.toString
+        case tt @ tpd.TypeTree() =>
+          typeTraverser.traverse(tt.tpe)
+        case _ =>
+          traverseChildren(tree)
+      }
+    }
+
+    for (tree <- stats)
+      traverser.traverse(tree)
+  }
+
+  private def unpkg(tree: tpd.Tree): List[tpd.Tree] =
+    tree match {
+      case PackageDef(_, elems) => elems.flatMap(unpkg)
+      case _ => List(tree)
+    }
+
+  def run(using Context): Unit =
+    val elems = unpkg(ctx.compilationUnit.tpdTree)
+    def mainStats(trees: List[tpd.Tree]): List[tpd.Tree] =
+      trees
+        .reverseIterator
+        .collectFirst {
+          case TypeDef(name, rhs0: Template) => rhs0.body
+        }
+        .getOrElse(Nil)
+
+    val rootStats = mainStats(elems)
+    val stats = (1 until userCodeNestingLevel)
+      .foldLeft(rootStats)((trees, _) => mainStats(trees))
+
+    if (needsUsedEarlierDefinitions) {
+      val wrapperSym = elems.last.symbol
+      updateUsedEarlierDefinitions(wrapperSym, stats)
+    }
+
+    stats.foreach {
+      case i: Import => processImport(i)
+      case t: tpd.DefDef => processTree(t)
+      case t: tpd.ValDef => processTree(t)
+      case t: tpd.TypeDef => processTree(t)
+      case _ =>
+    }
diff --git a/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/Compiler.scala b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/Compiler.scala
new file mode 100644
index 000000000..5046b8972
--- /dev/null
+++ b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/Compiler.scala
@@ -0,0 +1,523 @@
+package ammonite.compiler
+
+import java.net.URL
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Path, Paths}
+import java.io.{ByteArrayInputStream, OutputStream}
+
+import ammonite.compiler.iface.{
+  Compiler => ICompiler,
+  CompilerBuilder => ICompilerBuilder,
+  CompilerLifecycleManager => ICompilerLifecycleManager,
+  Preprocessor => IPreprocessor,
+  _
+}
+import ammonite.compiler.internal.CompilerHelper
+import ammonite.util.{ImportData, Imports, PositionOffsetConversion, Printer}
+import ammonite.util.Util.newLine
+
+import dotty.tools.dotc
+import dotc.{CompilationUnit, Compiler => DottyCompiler, Run, ScalacCommand}
+import dotc.ast.{tpd, untpd}
+import dotc.ast.Positioned
+import dotc.classpath
+import dotc.config.{CompilerCommand, JavaPlatform}
+import dotc.core.Contexts._
+import dotc.core.{Flags, MacroClassLoader, Mode}
+import dotc.core.Comments.{ContextDoc, ContextDocstrings}
+import dotc.core.Phases.{Phase, unfusedPhases}
+import dotc.core.Symbols.{defn, Symbol}
+import dotc.fromtasty.TastyFileUtil
+import dotc.interactive.Completion
+import dotc.report
+import dotc.reporting
+import dotc.semanticdb
+import dotc.transform.{PostTyper, Staging}
+import dotc.util.{Property, SourceFile, SourcePosition}
+import dotc.util.Spans.Span
+import dotty.tools.io.{
+  AbstractFile,
+  ClassPath,
+  ClassRepresentation,
+  File,
+  VirtualDirectory,
+  VirtualFile,
+  PlainFile
+}
+import dotty.tools.repl.CollectTopLevelImports
+
+class Compiler(
+  dynamicClassPath: AbstractFile,
+  initialClassPath: Seq[URL],
+  classPath: Seq[URL],
+  macroClassLoader: ClassLoader,
+  whiteList: Set[Seq[String]],
+  dependencyCompleteOpt: => Option[String => (Int, Seq[String])] = None,
+  contextInit: FreshContext => Unit = _ => (),
+  settings: Seq[String] = Nil,
+  reporter: Option[ICompilerBuilder.Message => Unit] = None
+) extends ICompiler:
+  self =>
+
+  import Compiler.{enumerateVdFiles, files}
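+
+  // Compiled output is collected in this in-memory directory first, then copied
+  // onto dynamicClassPath by Compiler.addToClasspath after each successful run.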
+  private val outputDir = new VirtualDirectory("(memory)")
+
+  private def initCtx: Context =
+    val base: ContextBase =
+      new ContextBase:
+        override protected def newPlatform(using Context) =
+          new JavaPlatform:
+            private var classPath0: ClassPath = null
+            override def classPath(using Context) =
+              if (classPath0 == null)
+                classPath0 = classpath.AggregateClassPath(Seq(
+                  asDottyClassPath(initialClassPath, whiteListed = true),
+                  asDottyClassPath(self.classPath),
+                  classpath.ClassPathFactory.newClassPath(dynamicClassPath)
+                ))
+              classPath0
+    base.initialCtx
+
+  private def sourcesRequired = false
+
+  private lazy val MacroClassLoaderKey =
+    val cls = macroClassLoader.loadClass("dotty.tools.dotc.core.MacroClassLoader$")
+    val fld = cls.getDeclaredField("MacroClassLoaderKey")
+    fld.setAccessible(true)
+    fld.get(null).asInstanceOf[Property.Key[ClassLoader]]
+
+  // Originally adapted from
+  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+  //   compiler/src/dotty/tools/dotc/Driver.scala/#L67-L81
+  private def setup(args: Array[String], rootCtx: Context): (List[String], Context) =
+    given ictx: FreshContext = rootCtx.fresh
+    val summary = ScalacCommand.distill(args, ictx.settings)(ictx.settingsState)(using ictx)
+    ictx.setSettings(summary.sstate)
+    ictx.setProperty(MacroClassLoaderKey, macroClassLoader)
+    Positioned.init
+
+    if !ictx.settings.YdropComments.value then
+      ictx.setProperty(ContextDoc, new ContextDocstrings)
+    val fileNamesOpt = ScalacCommand.checkUsage(
+      summary,
+      sourcesRequired
+    )(using ictx.settings)(using ictx.settingsState)
+    val fileNames = fileNamesOpt.getOrElse {
+      throw new Exception("Error initializing compiler")
+    }
+    contextInit(ictx)
+    (fileNames, ictx)
+
+  private def asDottyClassPath(
+    cp: Seq[URL],
+    whiteListed: Boolean = false
+  )(using Context): ClassPath =
+    val (dirs, jars) = cp.partition { url =>
+      url.getProtocol == "file" && Files.isDirectory(Paths.get(url.toURI))
+    }
+
+    val dirsCp = dirs.map(u => classpath.ClassPathFactory.newClassPath(AbstractFile.getURL(u)))
+    val jarsCp = jars
+      .filter(ammonite.util.Classpath.canBeOpenedAsJar)
+      .map(u => classpath.ZipAndJarClassPathFactory.create(AbstractFile.getURL(u)))
+
+    if (whiteListed) new dotty.ammonite.compiler.WhiteListClasspath(dirsCp ++ jarsCp, whiteList)
+    else classpath.AggregateClassPath(dirsCp ++ jarsCp)
+
+  // Originally adapted from
+  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+  //   compiler/src/dotty/tools/repl/ReplDriver.scala/#L67-L73
+  /** Create a fresh and initialized context with IDE mode enabled */
+  lazy val initialCtx =
+    val rootCtx = initCtx.fresh.addMode(Mode.ReadPositions | Mode.Interactive)
+    rootCtx.setSetting(rootCtx.settings.YcookComments, true)
+    // FIXME Disabled for the tests to pass
+    rootCtx.setSetting(rootCtx.settings.color, "never")
+    // FIXME We lose possible custom openStream implementations on the URLs of initialClassPath and
+    // classPath
+    val initialClassPath0 = initialClassPath
+    //   .filter(!_.toURI.toASCIIString.contains("fansi_2.13"))
+    //   .filter(!_.toURI.toASCIIString.contains("pprint_2.13"))
+    rootCtx.setSetting(rootCtx.settings.outputDir, outputDir)
+
+    val (_, ictx) = setup(settings.toArray, rootCtx)
+    ictx.base.initialize()(using ictx)
+    ictx
+
+  private var userCodeNestingLevel = -1
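+
+  // The REPL pipeline reuses the standard frontend phases, then runs SemanticDB
+  // extraction, the AmmonitePhase (which records imports and used earlier
+  // definitions), and PostTyper.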
+  // Originally adapted from
+  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+  //   compiler/src/dotty/tools/repl/ReplCompiler.scala/#L34-L39
+  val compiler =
+    new DottyCompiler:
+      override protected def frontendPhases: List[List[Phase]] =
+        CompilerHelper.frontEndPhases ++
+        List(
+          List(new semanticdb.ExtractSemanticDB),
+          List(new AmmonitePhase(userCodeNestingLevel, userCodeNestingLevel == 2)),
+          List(new PostTyper)
+        )
+
+  // Originally adapted from
+  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+  //   compiler/src/dotty/tools/repl/Rendering.scala/#L97-L103
+  /** Formats errors using the `messageRenderer` */
+  private def formatError(dia: reporting.Diagnostic)(implicit ctx: Context): reporting.Diagnostic =
+    val renderedMessage = CompilerHelper.messageAndPos(Compiler.messageRenderer, dia)
+    new reporting.Diagnostic(
+      reporting.NoExplanation(renderedMessage),
+      dia.pos,
+      dia.level
+    )
+
+  def compile(
+    src: Array[Byte],
+    printer: Printer,
+    importsLen: Int,
+    userCodeNestingLevel: Int,
+    fileName: String
+  ): Option[ICompiler.Output] =
+    // println(s"Compiling\n${new String(src, StandardCharsets.UTF_8)}\n")
+
+    self.userCodeNestingLevel = userCodeNestingLevel
+
+    val reporter0 = reporter match {
+      case None =>
+        Compiler.newStoreReporter()
+      case Some(rep) =>
+        val simpleReporter = new dotc.interfaces.SimpleReporter {
+          def report(diag: dotc.interfaces.Diagnostic) = {
+            val severity = diag.level match {
+              case dotc.interfaces.Diagnostic.ERROR => "ERROR"
+              case dotc.interfaces.Diagnostic.WARNING => "WARNING"
+              case dotc.interfaces.Diagnostic.INFO => "INFO"
+              case _ => "INFO" // should not happen
+            }
+            val pos = Some(diag.position).filter(_.isPresent).map(_.get)
+            val start = pos.fold(0)(_.start)
+            val end = pos.fold(new String(src, "UTF-8").length)(_.end)
+            val msg = ICompilerBuilder.Message(severity, start, end, diag.message)
+            rep(msg)
+          }
+        }
+        reporting.Reporter.fromSimpleReporter(simpleReporter)
+    }
+    val run = new Run(compiler, initialCtx.fresh.setReporter(reporter0))
+
+    val semanticDbEnabled = run.runContext.settings.Xsemanticdb.value(using run.runContext)
+    val sourceFile =
+      if (semanticDbEnabled) {
+        // semanticdb needs the sources to be written on disk, so we assume they're there already
+        val root = run.runContext.settings.sourceroot.value(using run.runContext)
+        SourceFile(AbstractFile.getFile(Paths.get(root).resolve(fileName)), "UTF-8")
+      } else {
+        val vf = new VirtualFile(fileName.split("/", -1).last, fileName)
+        val out = vf.output
+        out.write(src)
+        out.close()
+        new SourceFile(vf, new String(src, "UTF-8").toCharArray)
+      }
+
+    implicit val ctx: Context = run.runContext.withSource(sourceFile)
+
+    val unit =
+      new CompilationUnit(ctx.source):
+        // as done in
+        // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+        //   compiler/src/dotty/tools/repl/ReplCompilationUnit.scala/#L8
+        override def isSuspendable: Boolean = false
+    ctx
+      .run
+      .compileUnits(unit :: Nil)
+
+    val result =
+      if (ctx.reporter.hasErrors) Left(reporter.fold(ctx.reporter.removeBufferedMessages)(_ => Nil))
+      else Right((reporter.fold(ctx.reporter.removeBufferedMessages)(_ => Nil), unit))
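+
+    // Positions reported by the compiler refer to the wrapped code; the helpers
+    // below shift them back into the user's input by dropping the first
+    // importsLen characters of generated imports.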
+    def formatDiagnostics(diagnostics: List[reporting.Diagnostic]): List[String] = {
+      val scalaPosToScPos = PositionOffsetConversion.scalaPosToScPos(
+        new String(src).drop(importsLen),
+        0,
+        0,
+        new String(src),
+        importsLen
+      )
+      val scFile = new SourceFile(sourceFile.file, sourceFile.content().drop(importsLen))
+      def scalaOffsetToScOffset(scalaOffset: Int): Option[Int] =
+        scalaPosToScPos(sourceFile.offsetToLine(scalaOffset), sourceFile.column(scalaOffset)).map {
+          case (scLine, scCol) => scFile.lineToOffset(scLine) + scCol
+        }
+      def scalaSpanToScSpan(scalaSpan: Span): Option[Span] =
+        for {
+          scStart <- scalaOffsetToScOffset(scalaSpan.start)
+          scEnd <- scalaOffsetToScOffset(scalaSpan.end)
+          scPoint <- scalaOffsetToScOffset(scalaSpan.point)
+        } yield Span(scStart, scEnd, scPoint)
+      def scalaSourcePosToScSourcePos(sourcePos: SourcePosition): Option[SourcePosition] =
+        if (sourcePos.source == sourceFile)
+          scalaSpanToScSpan(sourcePos.span).map { scSpan =>
+            SourcePosition(scFile, scSpan, sourcePos.outer)
+          }
+        else
+          None
+      def scalaDiagnosticToScDiagnostic(diag: reporting.Diagnostic): Option[reporting.Diagnostic] =
+        scalaSourcePosToScSourcePos(diag.pos).map { scPos =>
+          new reporting.Diagnostic(diag.msg, scPos, diag.level)
+        }
+
+      diagnostics
+        .map(d => scalaDiagnosticToScDiagnostic(d).getOrElse(d))
+        .map(formatError)
+        .map(_.msg.toString)
+    }
+
+    result match {
+      case Left(errors) =>
+        for (err <- formatDiagnostics(errors))
+          printer.error(err)
+        None
+      case Right((warnings, unit)) =>
+        for (warn <- formatDiagnostics(warnings))
+          printer.warning(warn)
+        val newImports = unfusedPhases.collectFirst {
+          case p: AmmonitePhase => p.importData
+        }.getOrElse(Seq.empty[ImportData])
+        val usedEarlierDefinitions = unfusedPhases.collectFirst {
+          case p: AmmonitePhase => p.usedEarlierDefinitions
+        }.getOrElse(Seq.empty[String])
+        val fileCount = enumerateVdFiles(outputDir).length
+        val classes = files(outputDir).toArray
+        // outputDir is None here, dynamicClassPath should already correspond to an on-disk directory
+        Compiler.addToClasspath(classes, dynamicClassPath, None)
+        outputDir.clear()
+        val lineShift = PositionOffsetConversion.offsetToPos(new String(src)).apply(importsLen).line
+        val mappings = Map(sourceFile.file.name -> (sourceFile.file.name, -lineShift))
+        val postProcessedClasses = classes.toVector.map {
+          case (path, byteCode) if path.endsWith(".class") =>
+            val updatedByteCodeOpt = AsmPositionUpdater.postProcess(
+              mappings,
+              new ByteArrayInputStream(byteCode)
+            )
+            (path, updatedByteCodeOpt.getOrElse(byteCode))
+          case other =>
+            other
+        }
+        val output = ICompiler.Output(
+          postProcessedClasses,
+          Imports(newImports),
+          Some(usedEarlierDefinitions)
+        )
+        Some(output)
+    }
+
+  def objCompiler = compiler
+
+  def preprocessor(fileName: String, markGeneratedSections: Boolean): IPreprocessor =
+    new Preprocessor(
+      initialCtx.fresh.withSource(SourceFile.virtual(fileName, "")),
+      markGeneratedSections: Boolean
+    )
+
+  // Originally adapted from
+  // https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+  //   compiler/src/dotty/tools/repl/ReplCompiler.scala/#L224-L286
+  def tryTypeCheck(
+    src: Array[Byte],
+    fileName: String
+  ) =
+    val sourceFile = SourceFile.virtual(fileName, new String(src, StandardCharsets.UTF_8))
+
+    val reporter0 = Compiler.newStoreReporter()
+    val run = new Run(
+      compiler,
+      initialCtx.fresh
+        .addMode(Mode.ReadPositions | Mode.Interactive)
+        .setReporter(reporter0)
+        .setSetting(initialCtx.settings.YstopAfter, List("typer"))
+    )
+    implicit val ctx: Context = run.runContext.withSource(sourceFile)
+
+    val unit =
+      new CompilationUnit(ctx.source):
+        override def isSuspendable: Boolean = false
+    ctx
+      .run
+      .compileUnits(unit :: Nil, ctx)
+
+    (unit.tpdTree, ctx)
+
+  def complete(
+    offset: Int,
+    previousImports: String,
+    snippet: String
+  ): (Int, Seq[String], Seq[String]) = {
+
+    val prefix = previousImports + newLine +
+      "object AutocompleteWrapper{ val expr: _root_.scala.Unit = {" + newLine
+    val suffix = newLine + "()}}"
+    val allCode = prefix + snippet + suffix
+    val index = offset + prefix.length
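+
+    // Completions are computed against the wrapped code, so the cursor offset
+    // is shifted by prefix.length here and shifted back in the result below.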
+    // Originally based on
+    // https://github.com/lampepfl/dotty/blob/3.0.0-M1/
+    //   compiler/src/dotty/tools/repl/ReplDriver.scala/#L179-L191
+
+    val (tree, ctx0) = tryTypeCheck(allCode.getBytes("UTF-8"), "<completions>")
+    val ctx = ctx0.fresh
+    val file = SourceFile.virtual("<completions>", allCode, maybeIncomplete = true)
+    val unit = CompilationUnit(file)(using ctx)
+    unit.tpdTree = {
+      given Context = ctx
+      import tpd._
+      tree match {
+        case PackageDef(_, p) =>
+          p.collectFirst {
+            case TypeDef(_, tmpl: Template) =>
+              tmpl.body
+                .collectFirst { case dd: ValDef if dd.name.show == "expr" => dd }
+                .getOrElse(???)
+          }.getOrElse(???)
+        case _ => ???
+      }
+    }
+    val ctx1 = ctx.fresh.setCompilationUnit(unit)
+    val srcPos = SourcePosition(file, Span(index))
+    val (start, completions) = dotty.ammonite.compiler.AmmCompletion.completions(
+      srcPos,
+      dependencyCompleteOpt = dependencyCompleteOpt,
+      enableDeep = false
+    )(using ctx1)
+
+    val blacklistedPackages = Set("shaded")
+
+    def deepCompletion(name: String): List[String] = {
+      given Context = ctx1
+      def rec(t: Symbol): Seq[Symbol] = {
+        if (blacklistedPackages(t.name.toString))
+          Nil
+        else {
+          val children =
+            if (t.is(Flags.Package) || t.is(Flags.PackageVal) || t.is(Flags.PackageClass))
+              t.denot.info.allMembers.map(_.symbol).filter(_ != t).flatMap(rec)
+            else Nil
+
+          t +: children.toSeq
+        }
+      }
+
+      for {
+        member <- defn.RootClass.denot.info.allMembers.map(_.symbol).toList
+        sym <- rec(member)
+        // Scala 2 comment: sketchy name munging because I don't know how to do this properly
+        // Note lack of back-quoting support.
+        strippedName = sym.name.toString.stripPrefix("package$").stripSuffix("$")
+        if strippedName.startsWith(name)
+        (pref, _) = sym.fullName.toString.splitAt(sym.fullName.toString.lastIndexOf('.') + 1)
+        out = pref + strippedName
+        if out != ""
+      } yield out
+    }
+
+    def blacklisted(s: Symbol) = {
+      given Context = ctx1
+      val blacklist = Set(
+        "scala.Predef.any2stringadd.+",
+        "scala.Any.##",
+        "java.lang.Object.##",
+        "scala.<byname>",
+        "scala.<empty>",
+        "scala.<repeated>",
+        "scala.<repeated...>",
+        "scala.Predef.StringFormat.formatted",
+        "scala.Predef.Ensuring.ensuring",
+        "scala.Predef.ArrowAssoc.->",
+        "scala.Predef.ArrowAssoc.→",
+        "java.lang.Object.synchronized",
+        "java.lang.Object.ne",
+        "java.lang.Object.eq",
+        "java.lang.Object.wait",
+        "java.lang.Object.notifyAll",
+        "java.lang.Object.notify",
+        "java.lang.Object.clone",
+        "java.lang.Object.finalize"
+      )
+
+      blacklist(s.fullName.toString) ||
+      s.isOneOf(Flags.GivenOrImplicit) ||
+      // Cache objects, which you should probably never need to
+      // access directly, and apart from that have annoyingly long names
+      "cache[a-f0-9]{32}".r.findPrefixMatchOf(s.name.decode.toString).isDefined ||
+      // s.isDeprecated ||
+      s.name.decode.toString == "<init>" ||
+      s.name.decode.toString.contains('$')
+    }
+
+    val filteredCompletions = completions.filter { c =>
+      c.symbols.isEmpty || c.symbols.exists(!blacklisted(_))
+    }
+    val signatures = {
+      given Context = ctx1
+      for {
+        c <- filteredCompletions
+        s <- c.symbols
+        isMethod = s.denot.is(Flags.Method)
+        if isMethod
+      } yield s"def ${s.name}${s.denot.info.widenTermRefExpr.show}"
+    }
+    (start - prefix.length, filteredCompletions.map(_.label.replace(".package$.", ".")), signatures)
+  }
+
+object Compiler:
+
+  /** Create empty outer store reporter */
+  def newStoreReporter(): reporting.StoreReporter =
+    new reporting.StoreReporter(null)
+      with reporting.UniqueMessagePositions with reporting.HideNonSensicalMessages
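+
+  // Recursively lists the files of a VirtualDirectory, the files of each
+  // directory before the contents of its sub-directories.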
+  private def enumerateVdFiles(d: VirtualDirectory): Iterator[AbstractFile] =
+    val (subs, files) = d.iterator.partition(_.isDirectory)
+    files ++ subs.map(_.asInstanceOf[VirtualDirectory]).flatMap(enumerateVdFiles)
+
+  private def files(d: VirtualDirectory): Iterator[(String, Array[Byte])] =
+    for (x <- enumerateVdFiles(d) if x.name.endsWith(".class") || x.name.endsWith(".tasty")) yield {
+      val segments = x.path.split("/").toList.tail
+      (x.path.stripPrefix("(memory)/"), x.toByteArray)
+    }
+
+  private def writeDeep(
+    d: AbstractFile,
+    path: List[String]
+  ): OutputStream = path match {
+    case head :: Nil => d.fileNamed(path.head).output
+    case head :: rest =>
+      writeDeep(
+        d.subdirectoryNamed(head), //.asInstanceOf[VirtualDirectory],
+        rest
+      )
+    // We should never write to an empty path, and one of the above cases
+    // should catch this and return before getting here
+    case Nil => ???
+  }
+
+  def addToClasspath(classFiles: Traversable[(String, Array[Byte])],
+                     dynamicClasspath: AbstractFile,
+                     outputDir: Option[Path]): Unit = {
+
+    val outputDir0 = outputDir.map(os.Path(_, os.pwd))
+    for((name, bytes) <- classFiles){
+      val elems = name.split('/').toList
+      val output = writeDeep(dynamicClasspath, elems)
+      output.write(bytes)
+      output.close()
+
+      for (dir <- outputDir0)
+        os.write.over(dir / elems, bytes, createFolders = true)
+    }
+
+  }
+
+  private[compiler] val messageRenderer =
+    new reporting.MessageRendering {}
diff --git a/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/Preprocessor.scala b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/Preprocessor.scala
new file mode 100644
index 000000000..273005258
--- /dev/null
+++ b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/Preprocessor.scala
@@ -0,0 +1,332 @@
+package ammonite.compiler
+
+import java.util.function.{Function => JFunction}
+
+import ammonite.compiler.iface.{Compiler => _, Parser => _, Preprocessor => IPreprocessor, _}
+import ammonite.util.{Imports, Name, Res}
+import ammonite.util.Util.CodeSource
+import pprint.Util
+
+import dotty.tools.dotc
+import dotc.ast.desugar
+import dotc.ast.untpd
+import dotc.core.Contexts._
+import dotc.core.{Flags, Names}
+import dotc.parsing.Parsers.Parser
+import dotc.parsing.Tokens
+import dotc.util.SourceFile
+
+class Preprocessor(
+  ctx: Context,
+  markGeneratedSections: Boolean
+) extends IPreprocessor {
+
+  // FIXME Quite some duplication with DefaultProcessor for Scala 2.x
+
+  private case class Expanded(code: String, printer: Seq[String])
+
+  private def parse(source: String): Either[Seq[String], List[untpd.Tree]] = {
+    val reporter = Compiler.newStoreReporter()
+    val sourceFile = SourceFile.virtual("foo", source)
+    val parseCtx = ctx.fresh.setReporter(reporter).withSource(sourceFile)
+    val parser = new DottyParser(sourceFile)(using parseCtx)
+    val stats = parser.blockStatSeq()
+    parser.accept(Tokens.EOF)
+    if (reporter.hasErrors) {
+      val errorsStr = reporter
+        .allErrors
+        // .map(rendering.formatError)
+        .map(e => scala.util.Try(e.msg.toString).toOption.getOrElse("???"))
+      Left(errorsStr)
+    } else
+      Right(stats)
+  }
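+
+  // Entry point used by the REPL and script runner: expands the parsed
+  // statements, wraps them, and reports the length of the generated prefix so
+  // that positions can be mapped back to user code.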
+  def transform(
+    stmts: Seq[String],
+    resultIndex: String,
+    leadingSpaces: String,
+    codeSource: CodeSource,
+    indexedWrapper: Name,
+    imports: Imports,
+    printerTemplate: String => String,
+    extraCode: String,
+    skipEmpty: Boolean,
+    markScript: Boolean,
+    codeWrapper: CodeWrapper
+  ): Res[IPreprocessor.Output] = {
+
+    // println(s"transformOrNull(${stmts.toSeq})")
+
+    // All code Ammonite compiles must be rooted in some package within
+    // the `ammonite` top-level package
+    assert(codeSource.pkgName.head == Name("ammonite"))
+
+    expandStatements(stmts, resultIndex, skipEmpty).map {
+      case Expanded(code, printer) =>
+        val (wrappedCode, importsLength, userCodeNestingLevel) = wrapCode(
+          codeSource, indexedWrapper, leadingSpaces + code,
+          printerTemplate(printer.mkString(", ")),
+          imports, extraCode, markScript, codeWrapper
+        )
+        IPreprocessor.Output(wrappedCode, importsLength, userCodeNestingLevel)
+    }
+  }
+
+  private def expandStatements(
+    stmts: Seq[String],
+    wrapperIndex: String,
+    skipEmpty: Boolean
+  ): Res[Expanded] =
+    stmts match{
+      // In the REPL, we do not process empty inputs at all, to avoid
+      // unnecessarily incrementing the command counter
+      //
+      // But in scripts, we process empty inputs and create an empty object,
+      // to ensure that when the time comes to cache/load the class it exists
+      case Nil if skipEmpty => Res.Skip
+      case postSplit =>
+        Res(complete(stmts.mkString(""), wrapperIndex, postSplit))
+
+    }
+
+  private def wrapCode(
+    codeSource: CodeSource,
+    indexedWrapperName: Name,
+    code: String,
+    printCode: String,
+    imports: Imports,
+    extraCode: String,
+    markScript: Boolean,
+    codeWrapper: CodeWrapper
+  ) = {
+
+    //we need to normalize topWrapper and bottomWrapper in order to ensure
+    //the snippets always use the platform-specific newLine
+    val extraCode0 =
+      if (markScript) extraCode + "/*</amm>*/"
+      else extraCode
+    val (topWrapper, bottomWrapper, userCodeNestingLevel) =
+      codeWrapper(code, codeSource, imports, printCode, indexedWrapperName, extraCode0)
+    val (topWrapper0, bottomWrapper0) =
+      if (markScript) (topWrapper + "/*<amm>*/", "/*</amm>*/" + bottomWrapper)
+      else (topWrapper, bottomWrapper)
+    val importsLen = topWrapper0.length
+
+    (topWrapper0 + code + bottomWrapper0, importsLen, userCodeNestingLevel)
+  }
+
+  // Large parts of the logic below are adapted from DefaultProcessor,
+  // the Scala 2 counterpart of this file.
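+
+  // Each Processor/DefProc below matches one kind of statement and returns the
+  // (possibly rewritten) code together with the snippets that print its result.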
+  private def isPrivate(tree: untpd.Tree): Boolean =
+    tree match {
+      case m: untpd.MemberDef => m.mods.is(Flags.Private)
+      case _ => false
+    }
+
+  private def Processor(cond: PartialFunction[(String, String, untpd.Tree), Expanded]) =
+    (code: String, name: String, tree: untpd.Tree) => cond.lift(name, code, tree)
+
+  private def pprintSignature(ident: String, customMsg: Option[String]): String =
+    val customCode = customMsg.fold("_root_.scala.None")(x => s"""_root_.scala.Some("$x")""")
+    s"""
+    _root_.ammonite
+          .repl
+          .ReplBridge
+          .value
+          .Internal
+          .print($ident, ${Util.literalize(ident)}, $customCode)
+    """
+  private def definedStr(definitionLabel: String, name: String) =
+    s"""
+    _root_.ammonite
+          .repl
+          .ReplBridge
+          .value
+          .Internal
+          .printDef("$definitionLabel", ${Util.literalize(name)})
+    """
+  private def pprint(ident: String) = pprintSignature(ident, None)
+
+  /**
+   * Processors for declarations which all have the same shape
+   */
+  private def DefProc(definitionLabel: String)(cond: PartialFunction[untpd.Tree, Names.Name]) =
+    (code: String, name: String, tree: untpd.Tree) =>
+      cond.lift(tree).map{ name =>
+        val printer =
+          if (isPrivate(tree)) Nil
+          else
+            val definedName =
+              if name.isEmpty then ""
+              else Name.backtickWrap(name.decode.toString)
+            Seq(definedStr(definitionLabel, definedName))
+        Expanded(
+          code,
+          printer
+        )
+      }
+
+  private val ObjectDef = DefProc("object"){case m: untpd.ModuleDef => m.name}
+  private val ClassDef = DefProc("class"){
+    case m: untpd.TypeDef if m.isClassDef && !m.mods.flags.is(Flags.Trait) =>
+      m.name
+  }
+  private val TraitDef = DefProc("trait"){
+    case m: untpd.TypeDef if m.isClassDef && m.mods.flags.is(Flags.Trait) =>
+      m.name
+  }
+  private val DefDef = DefProc("function"){
+    case m: untpd.DefDef if m.mods.flags.is(Flags.Given) && m.name.isEmpty =>
+      given Context = ctx
+      desugar.inventGivenOrExtensionName(m.tpt)
+    case m: untpd.DefDef =>
+      m.name
+  }
+
+  private val ExtDef = DefProc("extension methods") {
+    case ext: untpd.ExtMethods => Names.EmptyTermName
+  }
+  private val TypeDef = DefProc("type"){ case m: untpd.TypeDef => m.name }
+
+  private val VarDef = Processor { case (name, code, t: untpd.ValDef) =>
+    Expanded(
+      //Only wrap rhs in function if it is not a function
+      //Wrapping functions causes type inference errors.
+      code,
+      // Try to leave out all synthetics; we don't actually have proper
+      // synthetic flags right now, because we're dumb-parsing it and not putting
+      // it through a full compilation
+      if (isPrivate(t) || t.name.decode.toString.contains("$")) Nil
+      else if (t.mods.flags.is(Flags.Given)) {
+        given Context = ctx
+        val name0 = if (t.name.isEmpty) desugar.inventGivenOrExtensionName(t.tpt) else t.name
+        Seq(pprintSignature(Name.backtickWrap(name0.decode.toString), Some("<given>")))
+      }
+      else if (t.mods.flags.is(Flags.Lazy))
+        Seq(pprintSignature(Name.backtickWrap(t.name.decode.toString), Some("<lazy>")))
+      else Seq(pprint(Name.backtickWrap(t.name.decode.toString)))
+    )
+  }
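+
+  // Pattern definitions like `val (a, b) = ...`: each bound name gets its own
+  // printer, while the code itself is emitted only once.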
+  private val PatDef = Processor { case (name, code, t: untpd.PatDef) =>
+    val isLazy = t.mods.flags.is(Flags.Lazy)
+    val printers =
+      if (isPrivate(t)) Nil
+      else
+        t.pats
+          .flatMap {
+            case untpd.Tuple(trees) => trees
+            case elem => List(elem)
+          }
+          .flatMap {
+            case untpd.Ident(name) =>
+              val decoded = name.decode.toString
+              if (decoded.contains("$")) Nil
+              else if (isLazy) Seq(pprintSignature(Name.backtickWrap(decoded), Some("<lazy>")))
+              else Seq(pprint(Name.backtickWrap(decoded)))
+            case _ => Nil // can this happen?
+          }
+    Expanded(code, printers)
+  }
+
+  private val Import = Processor {
+    case (name, code, tree: untpd.Import) =>
+      val Array(keyword, body) = code.split(" ", 2)
+      val tq = "\"\"\""
+      Expanded(code, Seq(
+        s"""
+        _root_.ammonite
+              .repl
+              .ReplBridge
+              .value
+              .Internal
+              .printImport(${Util.literalize(body)})
+        """
+      ))
+  }
+
+  private val Expr = Processor {
+    //Expressions are lifted to anon function applications so they will be JITed
+    case (name, code, tree) =>
+      val expandedCode =
+        if (markGeneratedSections)
+          s"/*<amm>*/val $name = /*</amm>*/$code"
+        else
+          s"val $name = $code"
+      Expanded(
+        expandedCode,
+        if (isPrivate(tree)) Nil else Seq(pprint(name))
+      )
+  }
+
+  private val decls = Seq[(String, String, untpd.Tree) => Option[Expanded]](
+    ObjectDef, ClassDef, TraitDef, DefDef, ExtDef, TypeDef, VarDef, PatDef, Import, Expr
+  )
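+
+  // Re-parses each statement, names its result `res<N>`, runs it through the
+  // processors above, and stitches the expanded snippets back together.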
+  private def complete(
+    code: String,
+    resultIndex: String,
+    postSplit: Seq[String]
+  ): Either[String, Expanded] = {
+    val reParsed = postSplit.map(p => (parse(p), p))
+    val errors = reParsed.collect{case (Left(e), _) => e }.flatten
+    if (errors.length != 0) Left(errors.mkString(System.lineSeparator()))
+    else {
+      val allDecls = for {
+        ((Right(trees), code), i) <- reParsed.zipWithIndex if trees.nonEmpty
+      } yield {
+        // Suffix the name of the result variable with the index of
+        // the tree if there is more than one statement in this command
+        val suffix = if (reParsed.length > 1) "_" + i else ""
+        def handleTree(t: untpd.Tree) = {
+          // println(s"handleTree($t)")
+          val it = decls.iterator.flatMap(_.apply(code, "res" + resultIndex + suffix, t))
+          if (it.hasNext)
+            it.next()
+          else {
+            sys.error(s"Don't know how to handle ${t.getClass}: $t")
+          }
+        }
+        trees match {
+          case Seq(tree) => handleTree(tree)
+
+          // This handles the multi-import case `import a.b, c.d`
+          case trees if trees.forall(_.isInstanceOf[untpd.Import]) => handleTree(trees(0))
+
+          // AFAIK this can only happen for pattern-matching multi-assignment,
+          // which for some reason parses into a list of statements. In such a
+          // scenario, aggregate all their printers, but only output the code once
+          case trees =>
+            val printers = for {
+              tree <- trees
+              if tree.isInstanceOf[untpd.ValDef]
+              Expanded(_, printers) = handleTree(tree)
+              printer <- printers
+            } yield printer
+
+            Expanded(code, printers)
+        }
+      }
+
+      val expanded = allDecls match{
+        case Seq(first, rest@_*) =>
+          val allDeclsWithComments = Expanded(first.code, first.printer) +: rest
+          allDeclsWithComments.reduce { (a, b) =>
+            Expanded(
+              // We do not need to separate the code with our own semi-colons
+              // or newlines, as each expanded code snippet itself comes with
+              // its own trailing newline/semicolons as a result of the
+              // initial split
+              a.code + b.code,
+              a.printer ++ b.printer
+            )
+          }
+        case Nil => Expanded("", Nil)
+      }
+
+      Right(expanded)
+    }
+  }
+}
diff --git a/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/SyntaxHighlighting.scala b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/SyntaxHighlighting.scala
new file mode 100644
index 000000000..0de0e7ed4
--- /dev/null
+++ b/amm/compiler/src/main/scala-3.3.2+/ammonite/compiler/SyntaxHighlighting.scala
@@ -0,0 +1,132 @@
+package ammonite.compiler
+
+// Originally adapted from
+// https://github.com/lampepfl/dotty/blob/3.0.0-M3/
+//   compiler/src/dotty/tools/dotc/printing/SyntaxHighlighting.scala
+
+import dotty.tools.dotc
+import dotc.CompilationUnit
+import dotc.ast.untpd
+import dotc.core.Contexts._
+import dotc.core.StdNames._
+import dotc.parsing.Parsers.Parser
+import dotc.parsing.Scanners.Scanner
+import dotc.parsing.Tokens._
+import dotc.reporting.Reporter
+import dotc.util.Spans.Span
+import dotc.util.SourceFile
+
+import java.util.Arrays
+
+/** This object provides functions for syntax highlighting in the REPL */
+class SyntaxHighlighting(
+  noAttrs: fansi.Attrs,
+  commentAttrs: fansi.Attrs,
+  keywordAttrs: fansi.Attrs,
+  valDefAttrs: fansi.Attrs,
+  literalAttrs: fansi.Attrs,
+  typeAttrs: fansi.Attrs,
+  annotationAttrs: fansi.Attrs,
+  notImplementedAttrs: fansi.Attrs
+) {
+
+  def highlight(in: String)(using Context): String = {
+    def freshCtx = ctx.fresh.setReporter(Reporter.NoReporter)
+    if (in.isEmpty || ctx.settings.color.value == "never") in
+    else {
+      val source = SourceFile.virtual("<highlighting>", in)
+
+      given Context = freshCtx
+        .setCompilationUnit(CompilationUnit(source, mustExist = false)(using freshCtx))
+
+      val colors = Array.fill(in.length)(0L)
+
+      def highlightRange(from: Int, to: Int, attr: fansi.Attrs) =
+        Arrays.fill(colors, from, to, attr.applyMask)
+
+      def highlightPosition(span: Span, attr: fansi.Attrs) =
+        if (span.exists && span.start >= 0 && span.end <= in.length)
+          highlightRange(span.start, span.end, attr)
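+
+      // First pass: token-based highlighting via the scanner (literals,
+      // keywords, soft modifiers). A second pass below walks the parsed trees
+      // for definition and type positions.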
+      val scanner = new Scanner(source)
+      while (scanner.token != EOF) {
+        val start = scanner.offset
+        val token = scanner.token
+        val name = scanner.name
+        val isSoftModifier = scanner.isSoftModifierInModifierPosition
+        scanner.nextToken()
+        val end = scanner.lastOffset
+
+        // Branch order is important. For example,
+        // `true` is at the same time a keyword and a literal
+        token match {
+          case _ if literalTokens.contains(token) =>
+            highlightRange(start, end, literalAttrs)
+
+          case STRINGPART =>
+            // String interpolation parts include `$` but
+            // we don't highlight it, hence the `-1`
+            highlightRange(start, end - 1, literalAttrs)
+
+          case _ if alphaKeywords.contains(token) || isSoftModifier =>
+            highlightRange(start, end, keywordAttrs)
+
+          case IDENTIFIER if name == nme.??? =>
+            highlightRange(start, end, notImplementedAttrs)
+
+          case _ =>
+        }
+      }
+
+      for {
+        comment <- scanner.comments
+        span = comment.span
+      } highlightPosition(span, commentAttrs)
+
+      object TreeHighlighter extends untpd.UntypedTreeTraverser {
+        import untpd._
+
+        def ignored(tree: NameTree) = {
+          val name = tree.name.toTermName
+          // trees named <error> and <init> have weird positions
+          name == nme.ERROR || name == nme.CONSTRUCTOR
+        }
+
+        def highlightAnnotations(tree: MemberDef): Unit =
+          for (annotation <- tree.mods.annotations)
+            highlightPosition(annotation.span, annotationAttrs)
+
+        def highlight(trees: List[Tree])(using Context): Unit =
+          trees.foreach(traverse)
+
+        def traverse(tree: Tree)(using Context): Unit = {
+          tree match {
+            case tree: NameTree if ignored(tree) =>
+              ()
+            case tree: ValOrDefDef =>
+              highlightAnnotations(tree)
+              highlightPosition(tree.nameSpan, valDefAttrs)
+            case tree: MemberDef /* ModuleDef | TypeDef */ =>
+              highlightAnnotations(tree)
+              highlightPosition(tree.nameSpan, typeAttrs)
+            case tree: Ident if tree.isType =>
+              highlightPosition(tree.span, typeAttrs)
+            case _: TypTree =>
+              highlightPosition(tree.span, typeAttrs)
+            case _ =>
+          }
+          traverseChildren(tree)
+        }
+      }
+
+      val parser = new DottyParser(source)
+      val trees = parser.blockStatSeq()
+      TreeHighlighter.highlight(trees)
+
+      // if (colorAt.last != NoColor)
+      //   highlighted.append(NoColor)
+
+      fansi.Str.fromArrays(in.toCharArray, colors).render
+    }
+  }
+}
diff --git a/build.sc b/build.sc
index d888aa33c..04540f191 100644
--- a/build.sc
+++ b/build.sc
@@ -36,7 +36,7 @@ val commitsSinceTaggedVersion = {
 val scala2_12Versions = Seq("2.12.8", "2.12.9", "2.12.10", "2.12.11", "2.12.12", "2.12.13", "2.12.14", "2.12.15", "2.12.16", "2.12.17", "2.12.18")
 val scala2_13Versions = Seq("2.13.2", "2.13.3", "2.13.4", "2.13.5", "2.13.6", "2.13.7", "2.13.8", "2.13.9", "2.13.10", "2.13.11", "2.13.12")
 val scala32Versions = Seq("3.2.0", "3.2.1", "3.2.2")
-val scala33Versions = Seq("3.3.0", "3.3.1")
+val scala33Versions = Seq("3.3.0", "3.3.1", "3.3.2", "3.3.3")
 val scala3Versions = scala32Versions ++ scala33Versions
 
 val binCrossScalaVersions = Seq(scala2_12Versions.last, scala2_13Versions.last, scala32Versions.last)
@@ -191,8 +191,14 @@ trait AmmInternalModule extends CrossSbtModule with Bloop.Module {
       if (sv.startsWith("2.13.") || sv.startsWith("3."))
         Seq(PathRef(millSourcePath / "src" / "main" / "scala-2.13-or-3"))
       else Nil
+    val extraDir5 =
+      if (sv.startsWith("3.3") && sv.stripPrefix("3.3.").toInt >= 2)
+        Seq(PathRef(millSourcePath / "src" / "main" / "scala-3.3.2+"))
+      else if (sv.startsWith("3"))
+        Seq(PathRef(millSourcePath / "src" / "main" / "scala-3.0.0-3.3.1"))
+      else Nil
 
-    super.sources() ++ extraDir ++ extraDir2 ++ extraDir3 ++ extraDir4
+    super.sources() ++ extraDir ++ extraDir2 ++ extraDir3 ++ extraDir4 ++ extraDir5
   }
   def externalSources = T{
     resolveDeps(allIvyDeps, sources = true)()