diff --git a/compiler/src/dotty/tools/repl/ReplDriver.scala b/compiler/src/dotty/tools/repl/ReplDriver.scala index ffa1b648446d..dd67b43ee0a9 100644 --- a/compiler/src/dotty/tools/repl/ReplDriver.scala +++ b/compiler/src/dotty/tools/repl/ReplDriver.scala @@ -168,9 +168,12 @@ class ReplDriver(settings: Array[String], * observable outside of the CLI, for this reason, most helper methods are * `protected final` to facilitate testing. */ - def runUntilQuit(using initialState: State = initialState)(): State = { + def runUntilQuit(using initialState: State = initialState)(hardcodedInput: java.io.InputStream = null): State = { val terminal = new JLineTerminal + val hardcodedInputLines = + if (hardcodedInput == null) null + else new java.io.BufferedReader(new java.io.InputStreamReader(hardcodedInput)) out.println( s"""Welcome to Scala $simpleVersionString ($javaVersion, Java $javaVmName). |Type in expressions for evaluation. Or try :help.""".stripMargin) @@ -208,8 +211,12 @@ class ReplDriver(settings: Array[String], } try { - val line = terminal.readLine(completer) - ParseResult(line) + val line = + if (hardcodedInputLines != null) hardcodedInputLines.readLine() + else terminal.readLine(completer) + + if (line == null) Quit + else ParseResult(line) } catch { case _: EndOfFileException => // Ctrl+D Quit diff --git a/compiler/src/dotty/tools/repl/ReplMain.scala b/compiler/src/dotty/tools/repl/ReplMain.scala new file mode 100644 index 000000000000..9d93ee9a6be4 --- /dev/null +++ b/compiler/src/dotty/tools/repl/ReplMain.scala @@ -0,0 +1,60 @@ +package dotty.tools.repl + +import java.io.PrintStream + +class ReplMain( + settings: Array[String] = Array.empty, + out: PrintStream = Console.out, + classLoader: Option[ClassLoader] = Some(getClass.getClassLoader), + predefCode: String = "", + testCode: String = "" +): + def run(bindings: ReplMain.Bind[_]*): Any = + try + ReplMain.currentBindings.set(bindings.map{bind => bind.name -> bind.value}.toMap) + + val bindingsPredef = bindings + .map { case bind => + s"def ${bind.name}: ${bind.typeName.value} = dotty.tools.repl.ReplMain.currentBinding[${bind.typeName.value}](\"${bind.name}\")" + } + .mkString("\n") + + val fullPredef = + ReplDriver.pprintImport + + (if bindingsPredef.nonEmpty then s"\n$bindingsPredef\n" else "") + + (if predefCode.nonEmpty then s"\n$predefCode\n" else "") + + val driver = new ReplDriver(settings, out, classLoader, fullPredef) + + if (testCode == "") driver.tryRunning + else driver.runUntilQuit(using driver.initialState)( + new java.io.ByteArrayInputStream(testCode.getBytes()) + ) + () + finally + ReplMain.currentBindings.set(null) + + +object ReplMain: + final case class TypeName[A](value: String) + object TypeName extends TypeNamePlatform + + import scala.quoted._ + + trait TypeNamePlatform: + inline given [A]: TypeName[A] = ${TypeNamePlatform.impl[A]} + + object TypeNamePlatform: + def impl[A](using t: Type[A], ctx: Quotes): Expr[TypeName[A]] = + '{TypeName[A](${Expr(Type.show[A])})} + + + case class Bind[T](name: String, value: () => T)(implicit val typeName: TypeName[T]) + object Bind: + implicit def ammoniteReplArrowBinder[T](t: (String, T))(implicit typeName: TypeName[T]): Bind[T] = { + Bind(t._1, () => t._2)(typeName) + } + + def currentBinding[T](s: String): T = currentBindings.get().apply(s).apply().asInstanceOf[T] + + private val currentBindings = new ThreadLocal[Map[String, () => Any]]() diff --git a/compiler/test/dotty/tools/repl/ReplMainTest.scala b/compiler/test/dotty/tools/repl/ReplMainTest.scala new file mode 100644 index 000000000000..495d688f035d --- /dev/null +++ b/compiler/test/dotty/tools/repl/ReplMainTest.scala @@ -0,0 +1,73 @@ +package dotty.tools +package repl + +import scala.language.unsafeNulls + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.charset.StandardCharsets + +import vulpix.TestConfiguration +import org.junit.Test +import org.junit.Assert._ + +/** Tests for the programmatic REPL API (ReplMain) */ +class ReplMainTest: + + private val defaultOptions = Array("-classpath", TestConfiguration.withCompilerClasspath) + + private def captureOutput(body: PrintStream => Unit): String = + val out = new ByteArrayOutputStream() + val ps = new PrintStream(out, true, StandardCharsets.UTF_8.name) + body(ps) + dotty.shaded.fansi.Str(out.toString(StandardCharsets.UTF_8.name)).plainText + + @Test def basicBinding(): Unit = + val output = captureOutput { out => + val replMain = new ReplMain( + settings = defaultOptions, + out = out, + testCode = "test" + ) + + replMain.run("test" -> 42) + } + + assertTrue(output.contains("val res0: Int = 42")) + + @Test def multipleBindings(): Unit = + val output = captureOutput { out => + val replMain = new ReplMain( + settings = defaultOptions, + out = out, + testCode = "x\ny\nz" + ) + + replMain.run( + "x" -> 1, + "y" -> "hello", + "z" -> true + ) + } + + assertTrue(output.contains("val res0: Int = 1")) + assertTrue(output.contains("val res1: String = \"hello\"")) + assertTrue(output.contains("val res2: Boolean = true")) + + @Test def bindingTypes(): Unit = + val output = captureOutput { out => + val replMain = new ReplMain( + settings = defaultOptions ++ Array("-repl-quit-after-init"), + out = out, + testCode = "list\nmap" + ) + + replMain.run( + "list" -> List(1, 2, 3), + "map" -> Map(1 -> "hello") + ) + } + + assertTrue(output.contains("val res0: List[Int] = List(1, 2, 3)")) + assertTrue(output.contains("val res1: Map[Int, String] = Map(1 -> \"hello\")")) + +end ReplMainTest diff --git a/sbt-bridge/src/xsbt/ConsoleInterface.java b/sbt-bridge/src/xsbt/ConsoleInterface.java index 3ba4e011c8e3..2f9ac33098d5 100644 --- a/sbt-bridge/src/xsbt/ConsoleInterface.java +++ b/sbt-bridge/src/xsbt/ConsoleInterface.java @@ -49,7 +49,7 @@ public void run( state = driver.run(initialCommands, state); // TODO handle failure during initialisation - state = driver.runUntilQuit(state); + state = driver.runUntilQuit(state, null); driver.run(cleanupCommands, state); } }