Skip to content

Commit 4aa008c

Browse files
authored
fix 'runBefore' in server mode (#180)
root issue of joernio/joern#4999
1 parent 046326b commit 4aa008c

File tree

7 files changed

+47
-16
lines changed

7 files changed

+47
-16
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,9 @@ curl http://localhost:8080/query-sync -X POST -d '{"query": "val bar = foo + 1"}
451451

452452
curl http://localhost:8080/query-sync -X POST -d '{"query":"println(\"OMG remote code execution!!1!\")"}'
453453
# {"success":true,"stdout":"",...}%
454+
455+
# combine with `jq` to directly get the output, as if you had a local console:
456+
curl --silent http://localhost:8080/query-sync -X POST -d '{"query": "val baz = 43"}' | jq --raw-output .stdout
454457
```
455458

456459
The same for windows and powershell:

build.sbt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ ThisBuild / organization := "com.michaelpollmeier"
55
ThisBuild / scalaVersion := "3.4.3"
66

77
lazy val ScalaTestVersion = "3.2.18"
8+
lazy val Slf4jVersion = "2.0.16"
89

910
lazy val shadedLibs = project.in(file("shaded-libs"))
1011
.settings(
@@ -26,7 +27,7 @@ lazy val core = project.in(file("core"))
2627
executableScriptName := "srp",
2728
libraryDependencies ++= Seq(
2829
"org.scala-lang" %% "scala3-compiler" % scalaVersion.value,
29-
"org.slf4j" % "slf4j-simple" % "2.0.16" % Optional,
30+
"org.slf4j" % "slf4j-simple" % Slf4jVersion % Optional,
3031
),
3132
assemblyJarName := "srp.jar", // TODO remove the '.jar' suffix - when doing so, it doesn't work any longer
3233
)
@@ -43,6 +44,7 @@ lazy val server = project.in(file("server"))
4344
fork := true, // important: otherwise we run into classloader issues
4445
libraryDependencies ++= Seq(
4546
"com.lihaoyi" %% "cask" % "0.8.3",
47+
"org.slf4j" % "slf4j-simple" % Slf4jVersion % Optional,
4648
"com.lihaoyi" %% "requests" % "0.8.2" % Test,
4749
"org.scalatest" %% "scalatest" % ScalaTestVersion % "it",
4850
)

core/src/main/scala/replpp/Config.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ object Config {
153153

154154
def runBefore[C](using builder: OParserBuilder[C])(action: Action[String, C]) = {
155155
builder.opt[String]("runBefore")
156-
.valueName("'val foo = 42'")
156+
.valueName("\"import Int.MaxValue\"")
157157
.unbounded()
158158
.optional()
159159
.action(action)

server/src/it/scala/replpp/server/EmbeddedReplTests.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ import scala.concurrent.duration.Duration
1313
class EmbeddedReplTests extends AnyWordSpec with Matchers {
1414

1515
"execute commands synchronously" in {
16-
val repl = new EmbeddedRepl(defaultCompilerArgs)
16+
val repl = new EmbeddedRepl(defaultCompilerArgs, runBeforeCode = Seq("import Short.MaxValue"))
1717

18-
repl.query("val x = 0").output.trim shouldBe "val x: Int = 0"
19-
repl.query("x + 1").output.trim shouldBe "val res0: Int = 1"
18+
repl.query("val x = MaxValue").output.trim shouldBe "val x: Short = 32767"
19+
repl.query("x + 1").output.trim shouldBe "val res0: Int = 32768"
2020

2121
repl.shutdown()
2222
}
2323

2424
"execute a command asynchronously" in {
25-
val repl = new EmbeddedRepl(defaultCompilerArgs)
26-
val (uuid, futureResult) = repl.queryAsync("val x = 0")
25+
val repl = new EmbeddedRepl(defaultCompilerArgs, runBeforeCode = Seq("import Short.MaxValue"))
26+
val (uuid, futureResult) = repl.queryAsync("val x = MaxValue")
2727
val result = Await.result(futureResult, Duration.Inf)
28-
result.trim shouldBe "val x: Int = 0"
28+
result.trim shouldBe "val x: Short = 32767"
2929
repl.shutdown()
3030
}
3131

server/src/main/scala/replpp/server/EmbeddedRepl.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@ import java.util.concurrent.Executors
1212
import scala.concurrent.duration.Duration
1313
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
1414

15-
class EmbeddedRepl(compilerArgs: Array[String]) {
15+
class EmbeddedRepl(compilerArgs: Array[String], runBeforeCode: Seq[String]) {
1616
private val logger: Logger = LoggerFactory.getLogger(getClass)
1717

1818
/** repl and compiler output ends up in this replOutputStream */
1919
private val replOutputStream = new ByteArrayOutputStream()
2020

21-
private val replDriver: ReplDriver = {
22-
new ReplDriver(compilerArgs, new PrintStream(replOutputStream), classLoader = None)
23-
}
21+
private val replDriver = ReplDriver(compilerArgs, new PrintStream(replOutputStream), classLoader = None)
2422

25-
private var state: State = replDriver.initialState
23+
private var state: State = {
24+
if (runBeforeCode.nonEmpty)
25+
replDriver.execute(runBeforeCode)
26+
else
27+
replDriver.initialState
28+
}
2629

2730
private val singleThreadedJobExecutor: ExecutionContextExecutorService =
2831
ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

server/src/main/scala/replpp/server/ReplServer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ object ReplServer {
2323

2424
val baseConfig = precompilePredefFiles(serverConfig.baseConfig)
2525
val compilerArgs = replpp.compilerArgs(baseConfig)
26-
val embeddedRepl = new EmbeddedRepl(compilerArgs)
26+
val embeddedRepl = new EmbeddedRepl(compilerArgs, baseConfig.runBefore)
2727
Runtime.getRuntime.addShutdownHook(new Thread(() => {
2828
logger.info("Shutting down server...")
2929
embeddedRepl.shutdown()

server/src/test/scala/replpp/server/ReplServerTests.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,29 @@ class ReplServerTests extends AnyWordSpec with Matchers {
123123
getResultResponse("stdout").str shouldBe "val bar: Int = 42\n"
124124
}
125125

126+
"use runBefore code" in Fixture(runBeforeCode = Seq("import Int.MaxValue")) { url =>
127+
val wsMsgPromise = scala.concurrent.Promise[String]()
128+
val connectedPromise = scala.concurrent.Promise[String]()
129+
cask.util.WsClient.connect(s"$url/connect") {
130+
case cask.Ws.Text(msg) if msg == "connected" =>
131+
connectedPromise.success(msg)
132+
case cask.Ws.Text(msg) =>
133+
wsMsgPromise.success(msg)
134+
}
135+
Await.result(connectedPromise.future, DefaultPromiseAwaitTimeout)
136+
val postQueryResponse = postQueryAsync(url, "val bar = MaxValue")
137+
val queryUUID = postQueryResponse("uuid").str
138+
queryUUID.length should not be 0
139+
140+
val queryResultWSMessage = Await.result(wsMsgPromise.future, DefaultPromiseAwaitTimeout)
141+
queryResultWSMessage.length should not be 0
142+
143+
val getResultResponse = getResponse(url, queryUUID)
144+
getResultResponse.obj.keySet should contain("success")
145+
getResultResponse("uuid").str shouldBe queryResultWSMessage
146+
getResultResponse("stdout").str shouldBe "val bar: Int = 2147483647\n"
147+
}
148+
126149
"disallow fetching the result of a completed query with an invalid auth header" in Fixture() { url =>
127150
val wsMsgPromise = scala.concurrent.Promise[String]()
128151
val connectedPromise = scala.concurrent.Promise[String]()
@@ -384,7 +407,7 @@ object Fixture {
384407
)
385408
}
386409

387-
def apply[T](predefCode: String = "")(urlToResult: String => T): T = {
410+
def apply[T](predefCode: String = "", runBeforeCode: Seq[String] = Seq.empty)(urlToResult: String => T): T = {
388411
val additionalClasspathEntryMaybe: Option[Path] =
389412
if (predefCode.trim.isEmpty) None
390413
else {
@@ -395,7 +418,7 @@ object Fixture {
395418
Files.delete(predefFile)
396419
predefClassfiles.toOption
397420
}
398-
val embeddedRepl = new EmbeddedRepl(compilerArgs(additionalClasspathEntryMaybe))
421+
val embeddedRepl = new EmbeddedRepl(compilerArgs(additionalClasspathEntryMaybe), runBeforeCode)
399422

400423
val host = "localhost"
401424
val port = 8081

0 commit comments

Comments
 (0)