From 2cc952cd3eec72493baf4631c0cb3e6d3bf404c8 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Wed, 19 Feb 2020 13:28:45 -0800 Subject: [PATCH 01/16] clean up workers --- scala/scalafmt/BUILD | 4 +- ...lafmtRunner.scala => ScalafmtWorker.scala} | 20 +-- scala_proto/scala_proto_toolchain.bzl | 2 +- .../instrumenter/JacocoInstrumenter.java | 27 ++-- src/java/io/bazel/rulesscala/scalac/BUILD | 5 +- .../rulesscala/scalac/ScalaCInvoker.java | 47 ------- ...ScalacProcessor.java => ScalacWorker.java} | 49 ++++++-- src/java/io/bazel/rulesscala/worker/BUILD | 9 +- .../rulesscala/worker/GenericWorker.java | 117 ------------------ .../io/bazel/rulesscala/worker/Processor.java | 7 -- .../io/bazel/rulesscala/worker/Worker.java | 114 +++++++++++++++++ src/scala/scripts/BUILD | 16 +-- src/scala/scripts/PBGenerateRequest.scala | 18 +-- ...aPBGenerator.scala => ScalaPBWorker.scala} | 24 +--- ...ogeGenerator.scala => ScroogeWorker.scala} | 32 +---- twitter_scrooge/twitter_scrooge.bzl | 2 +- 16 files changed, 202 insertions(+), 291 deletions(-) rename scala/scalafmt/scalafmt/{ScalafmtRunner.scala => ScalafmtWorker.scala} (71%) delete mode 100644 src/java/io/bazel/rulesscala/scalac/ScalaCInvoker.java rename src/java/io/bazel/rulesscala/scalac/{ScalacProcessor.java => ScalacWorker.java} (88%) delete mode 100644 src/java/io/bazel/rulesscala/worker/GenericWorker.java delete mode 100644 src/java/io/bazel/rulesscala/worker/Processor.java create mode 100644 src/java/io/bazel/rulesscala/worker/Worker.java rename src/scala/scripts/{ScalaPBGenerator.scala => ScalaPBWorker.scala} (80%) rename src/scala/scripts/{TwitterScroogeGenerator.scala => ScroogeWorker.scala} (81%) diff --git a/scala/scalafmt/BUILD b/scala/scalafmt/BUILD index 1a66f0fd3..942a54b7b 100644 --- a/scala/scalafmt/BUILD +++ b/scala/scalafmt/BUILD @@ -14,8 +14,8 @@ filegroup( scala_binary( name = "scalafmt", - srcs = ["scalafmt/ScalafmtRunner.scala"], - main_class = "io.bazel.rules_scala.scalafmt.ScalafmtRunner", + srcs = ["scalafmt/ScalafmtWorker.scala"], + main_class = "io.bazel.rules_scala.scalafmt.ScalafmtWorker", visibility = ["//visibility:public"], deps = [ "//src/java/io/bazel/rulesscala/worker", diff --git a/scala/scalafmt/scalafmt/ScalafmtRunner.scala b/scala/scalafmt/scalafmt/ScalafmtWorker.scala similarity index 71% rename from scala/scalafmt/scalafmt/ScalafmtRunner.scala rename to scala/scalafmt/scalafmt/ScalafmtWorker.scala index 5a4a870f1..6af48bae8 100644 --- a/scala/scalafmt/scalafmt/ScalafmtRunner.scala +++ b/scala/scalafmt/scalafmt/ScalafmtWorker.scala @@ -1,6 +1,6 @@ package io.bazel.rules_scala.scalafmt -import io.bazel.rulesscala.worker.{GenericWorker, Processor}; +import io.bazel.rulesscala.worker.Worker import java.io.File import java.nio.file.Files import org.scalafmt.Scalafmt @@ -10,21 +10,13 @@ import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.io.Codec -object ScalafmtRunner extends GenericWorker(new ScalafmtProcessor) { - def main(args: Array[String]) { - try run(args) - catch { - case x: Exception => - x.printStackTrace() - System.exit(1) - } - } -} +object ScalafmtWorker extends Worker.Interface { + + def main(args: Array[String]): Unit = Worker.workerMain(args, ScalafmtWorker) -class ScalafmtProcessor extends Processor { - def processRequest(args: java.util.List[String]) { + def work(args: Array[String]) { val argName = List("config", "input", "output") - val argFile = args.asScala.map{x => new File(x)} + val argFile = args.map{x => new File(x)} val namespace = argName.zip(argFile).toMap val source = FileOps.readFile(namespace.getOrElse("input", new File("")))(Codec.UTF8) diff --git a/scala_proto/scala_proto_toolchain.bzl b/scala_proto/scala_proto_toolchain.bzl index a733a2fbf..d13aa2b47 100644 --- a/scala_proto/scala_proto_toolchain.bzl +++ b/scala_proto/scala_proto_toolchain.bzl @@ -29,7 +29,7 @@ scala_proto_toolchain = rule( "code_generator": attr.label( executable = True, cfg = "host", - default = Label("@io_bazel_rules_scala//src/scala/scripts:scalapb_generator"), + default = Label("@io_bazel_rules_scala//src/scala/scripts:scalapb_worker"), allow_files = True, ), "named_generators": attr.string_dict(), diff --git a/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java b/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java index 86c450b63..cba85e893 100644 --- a/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java +++ b/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java @@ -1,8 +1,7 @@ package io.bazel.rulesscala.coverage.instrumenter; import io.bazel.rulesscala.jar.JarCreator; -import io.bazel.rulesscala.worker.GenericWorker; -import io.bazel.rulesscala.worker.Processor; +import io.bazel.rulesscala.worker.Worker; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; @@ -26,28 +25,18 @@ import org.jacoco.core.instr.Instrumenter; import org.jacoco.core.runtime.OfflineInstrumentationAccessGenerator; -public final class JacocoInstrumenter implements Processor { +public final class JacocoInstrumenter implements Worker.Interface { - public static void main(String[] args) throws Exception { - (new Worker()).run(args); - } - - private static final class Worker extends GenericWorker { - public Worker() { - super(new JacocoInstrumenter()); - } + public static void main(String args[]) throws Exception { + Worker.workerMain(args, new JacocoInstrumenter()); } @Override - public void processRequest(List < String > args) { + public void work(String args[]) throws Exception { Instrumenter jacoco = new Instrumenter(new OfflineInstrumentationAccessGenerator()); - args.forEach(arg -> { - try { - processArg(jacoco, arg); - } catch (final Exception e) { - throw new RuntimeException(e); - } - }); + for (String arg : args) { + processArg(jacoco, arg); + } } private void processArg(Instrumenter jacoco, String arg) throws Exception { diff --git a/src/java/io/bazel/rulesscala/scalac/BUILD b/src/java/io/bazel/rulesscala/scalac/BUILD index 2cf4ded22..1fbb9d1be 100644 --- a/src/java/io/bazel/rulesscala/scalac/BUILD +++ b/src/java/io/bazel/rulesscala/scalac/BUILD @@ -17,7 +17,7 @@ java_binary( "-source 1.8", "-target 1.8", ], - main_class = "io.bazel.rulesscala.scalac.ScalaCInvoker", + main_class = "io.bazel.rulesscala.scalac.ScalacWorker", visibility = ["//visibility:public"], deps = [ ":exported_scalac_repositories_from_toolchain_to_jvm", @@ -33,8 +33,7 @@ filegroup( srcs = [ "CompileOptions.java", "Resource.java", - "ScalaCInvoker.java", - "ScalacProcessor.java", + "ScalacWorker.java", ], visibility = ["//visibility:public"], ) diff --git a/src/java/io/bazel/rulesscala/scalac/ScalaCInvoker.java b/src/java/io/bazel/rulesscala/scalac/ScalaCInvoker.java deleted file mode 100644 index ab60f95f0..000000000 --- a/src/java/io/bazel/rulesscala/scalac/ScalaCInvoker.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2014 The Bazel Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package io.bazel.rulesscala.scalac; - -import io.bazel.rulesscala.worker.GenericWorker; -import java.io.PrintStream; -import scala.Console$; - -/** - * This is our entry point to producing a scala target this can act as one of Bazel's persistant - * workers. - */ -public class ScalaCInvoker extends GenericWorker { - public ScalaCInvoker() { - super(new ScalacProcessor()); - } - - @Override - protected void setupOutput(PrintStream ps) { - System.setOut(ps); - System.setErr(ps); - Console$.MODULE$.setErrDirect(ps); - Console$.MODULE$.setOutDirect(ps); - } - - public static void main(String[] args) { - try { - GenericWorker w = new ScalaCInvoker(); - w.run(args); - } catch (Exception ex) { - ex.printStackTrace(); - System.exit(1); - } - } -} diff --git a/src/java/io/bazel/rulesscala/scalac/ScalacProcessor.java b/src/java/io/bazel/rulesscala/scalac/ScalacWorker.java similarity index 88% rename from src/java/io/bazel/rulesscala/scalac/ScalacProcessor.java rename to src/java/io/bazel/rulesscala/scalac/ScalacWorker.java index fe5b7721f..541f7079a 100644 --- a/src/java/io/bazel/rulesscala/scalac/ScalacProcessor.java +++ b/src/java/io/bazel/rulesscala/scalac/ScalacWorker.java @@ -1,8 +1,7 @@ package io.bazel.rulesscala.scalac; import io.bazel.rulesscala.jar.JarCreator; -import io.bazel.rulesscala.worker.GenericWorker; -import io.bazel.rulesscala.worker.Processor; +import io.bazel.rulesscala.worker.Worker; import java.io.File; import java.io.FileNotFoundException; import java.io.FileOutputStream; @@ -28,7 +27,7 @@ import scala.tools.nsc.MainClass; import scala.tools.nsc.reporters.ConsoleReporter; -class ScalacProcessor implements Processor { +class ScalacWorker implements Worker.Interface { private static boolean isWindows = System.getProperty("os.name").toLowerCase().contains("windows"); /** This is the reporter field for scalac, which we want to access */ @@ -43,11 +42,15 @@ class ScalacProcessor implements Processor { } } + public static void main(String args[]) throws Exception { + Worker.workerMain(args, new ScalacWorker()); + } + @Override - public void processRequest(List args) throws Exception { + public void work(String[] args) throws Exception { Path tmpPath = null; try { - CompileOptions ops = new CompileOptions(args); + CompileOptions ops = new CompileOptions(Arrays.asList(args)); Path outputPath = FileSystems.getDefault().getPath(ops.outputName); tmpPath = Files.createTempDirectory(outputPath.getParent(), "tmp"); @@ -63,7 +66,7 @@ public void processRequest(List args) throws Exception { String[] scalaSources = collectSrcJarSources(ops.files, scalaJarFiles, javaJarFiles); - String[] javaSources = GenericWorker.appendToString(ops.javaFiles, javaJarFiles); + String[] javaSources = appendToString(ops.javaFiles, javaJarFiles); if (scalaSources.length == 0 && javaSources.length == 0) { throw new RuntimeException("Must have input files from either source jars or local files."); } @@ -95,8 +98,8 @@ public void processRequest(List args) throws Exception { private static String[] collectSrcJarSources( String[] files, List scalaJarFiles, List javaJarFiles) { - String[] scalaSources = GenericWorker.appendToString(files, scalaJarFiles); - return GenericWorker.appendToString(scalaSources, javaJarFiles); + String[] scalaSources = appendToString(files, scalaJarFiles); + return appendToString(scalaSources, javaJarFiles); } private static List filterFilesByExtension(List files, String extension) { @@ -166,7 +169,7 @@ private static boolean matchesFileExtensions(String fileName, String[] extension } private static String[] encodeBazelTargets(String[] targets) { - return Arrays.stream(targets).map(ScalacProcessor::encodeBazelTarget).toArray(String[]::new); + return Arrays.stream(targets).map(ScalacWorker::encodeBazelTarget).toArray(String[]::new); } private static String encodeBazelTarget(String target) { @@ -223,7 +226,7 @@ private static void compileScalaSources(CompileOptions ops, String[] scalaSource String[] constParams = {"-classpath", ops.classpath, "-d", tmpPath.toString()}; String[] compilerArgs = - GenericWorker.merge(ops.scalaOpts, ops.pluginArgs, constParams, pluginParams, scalaSources); + merge(ops.scalaOpts, ops.pluginArgs, constParams, pluginParams, scalaSources); MainClass comp = new MainClass(); long start = System.currentTimeMillis(); @@ -312,4 +315,30 @@ private static void copyResourceJars(String[] resourceJars, Path dest) throws IO extractJar(jarPath, dest.toString(), null); } } + + private static String[] appendToString(String[] init, List rest) { + String[] tmp = new String[init.length + rest.size()]; + System.arraycopy(init, 0, tmp, 0, init.length); + int baseIdx = init.length; + for (T t : rest) { + tmp[baseIdx] = t.toString(); + baseIdx += 1; + } + return tmp; + } + + private static String[] merge(String[]... arrays) { + int totalLength = 0; + for (String[] arr : arrays) { + totalLength += arr.length; + } + + String[] result = new String[totalLength]; + int offset = 0; + for (String[] arr : arrays) { + System.arraycopy(arr, 0, result, offset, arr.length); + offset += arr.length; + } + return result; + } } diff --git a/src/java/io/bazel/rulesscala/worker/BUILD b/src/java/io/bazel/rulesscala/worker/BUILD index 61c69e98d..6a08c9643 100644 --- a/src/java/io/bazel/rulesscala/worker/BUILD +++ b/src/java/io/bazel/rulesscala/worker/BUILD @@ -1,12 +1,7 @@ -load("@rules_java//java:defs.bzl", "java_library") - java_library( name = "worker", - srcs = [ - "GenericWorker.java", - "Processor.java", - ], - visibility = ["//visibility:public"], + srcs = ["Worker.java"], + visibility = ["//:__subpackages__"], deps = [ "//third_party/bazel/src/main/protobuf:worker_protocol_java_proto", ], diff --git a/src/java/io/bazel/rulesscala/worker/GenericWorker.java b/src/java/io/bazel/rulesscala/worker/GenericWorker.java deleted file mode 100644 index d26bed200..000000000 --- a/src/java/io/bazel/rulesscala/worker/GenericWorker.java +++ /dev/null @@ -1,117 +0,0 @@ -package io.bazel.rulesscala.worker; - -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest; -import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.PrintStream; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.Arrays; -import java.util.List; - -public class GenericWorker { - protected final Processor processor; - - public GenericWorker(Processor p) { - processor = p; - } - - protected void setupOutput(PrintStream ps) { - System.setOut(ps); - System.setErr(ps); - } - - // Mostly lifted from bazel - private void runPersistentWorker() throws IOException { - PrintStream originalStdOut = System.out; - PrintStream originalStdErr = System.err; - - while (true) { - try { - WorkRequest request = WorkRequest.parseDelimitedFrom(System.in); - if (request == null) { - break; - } - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - int exitCode = 0; - - try (PrintStream ps = new PrintStream(baos)) { - setupOutput(ps); - - try { - processor.processRequest(request.getArgumentsList()); - } catch (Exception e) { - e.printStackTrace(); - exitCode = 1; - } - } finally { - System.setOut(originalStdOut); - System.setErr(originalStdErr); - } - - WorkResponse.newBuilder() - .setOutput(baos.toString()) - .setExitCode(exitCode) - .build() - .writeDelimitedTo(System.out); - System.out.flush(); - } finally { - System.gc(); - } - } - } - - public static String[] appendToString(String[] init, List rest) { - String[] tmp = new String[init.length + rest.size()]; - System.arraycopy(init, 0, tmp, 0, init.length); - int baseIdx = init.length; - for (T t : rest) { - tmp[baseIdx] = t.toString(); - baseIdx += 1; - } - return tmp; - } - - public static String[] merge(String[]... arrays) { - int totalLength = 0; - for (String[] arr : arrays) { - totalLength += arr.length; - } - - String[] result = new String[totalLength]; - int offset = 0; - for (String[] arr : arrays) { - System.arraycopy(arr, 0, result, offset, arr.length); - offset += arr.length; - } - return result; - } - - private boolean contains(String[] args, String s) { - for (String str : args) { - if (str.equals(s)) return true; - } - return false; - } - - private static List normalize(List args) throws IOException { - if (args.size() == 1 && args.get(0).startsWith("@")) { - return Files.readAllLines(Paths.get(args.get(0).substring(1)), UTF_8); - } else { - return args; - } - } - - /** This is expected to be called by a main method */ - public void run(String[] argArray) throws Exception { - if (contains(argArray, "--persistent_worker")) { - runPersistentWorker(); - } else { - List args = Arrays.asList(argArray); - processor.processRequest(normalize(args)); - } - } -} diff --git a/src/java/io/bazel/rulesscala/worker/Processor.java b/src/java/io/bazel/rulesscala/worker/Processor.java deleted file mode 100644 index 126bcb666..000000000 --- a/src/java/io/bazel/rulesscala/worker/Processor.java +++ /dev/null @@ -1,7 +0,0 @@ -package io.bazel.rulesscala.worker; - -import java.util.List; - -public interface Processor { - void processRequest(List args) throws Exception; -} diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java new file mode 100644 index 000000000..6153e4518 --- /dev/null +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -0,0 +1,114 @@ +package io.bazel.rulesscala.worker; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.lang.SecurityManager; +import java.security.Permission; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.nio.charset.StandardCharsets; + +import com.google.devtools.build.lib.worker.WorkerProtocol; + +public final class Worker { + + public static interface Interface { + public void work(String args[]) throws Exception; + } + + public static void workerMain(String workerArgs[], Interface workerInterface) throws Exception { + if (workerArgs.length > 0 && workerArgs[0].equals("--persistent_worker")) { + + System.setSecurityManager(new SecurityManager() { + @Override + public void checkPermission(Permission permission) { + Matcher matcher = exitPattern.matcher(permission.getName()); + if (matcher.find()) + throw new ExitTrapped(Integer.parseInt(matcher.group(1))); + } + }); + + InputStream stdin = System.in; + PrintStream stdout = System.out; + PrintStream stderr = System.err; + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + PrintStream out = new PrintStream(outStream); + + System.setIn(new ByteArrayInputStream(new byte[0])); + System.setOut(out); + System.setErr(out); + + try { + while (true) { + WorkerProtocol.WorkRequest request = + WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin); + + int code = 0; + + try { + workerInterface.work(stringListToArray(request.getArgumentsList())); + } catch (ExitTrapped e) { + code = e.code; + } catch (Exception e) { + System.err.println(e.getMessage()); + e.printStackTrace(); + code = 1; + } + + WorkerProtocol.WorkResponse.newBuilder() + .setOutput(outStream.toString()) + .setExitCode(code) + .build() + .writeDelimitedTo(stdout); + + out.flush(); + outStream.reset(); + } + } catch (IOException e) { + } finally { + System.setIn(stdin); + System.setOut(stdout); + System.setErr(stderr); + } + } else { + String[] args; + if (workerArgs.length == 1 && workerArgs[0].startsWith("@")) { + args = stringListToArray(Files.readAllLines(Paths.get(workerArgs[0].substring(1)), StandardCharsets.UTF_8)); + } else { + args = workerArgs; + } + workerInterface.work(workerArgs); + } + } + + private static class ExitTrapped extends RuntimeException { + final int code; + ExitTrapped(int code) { + super(); + this.code = code; + } + } + + private static Pattern exitPattern = + Pattern.compile("exitVM\\.(-?\\d+)"); + + private static String[] stringListToArray(List argList) { + int numArgs = argList.size(); + String[] args = new String[numArgs]; + for (int i = 0; i < numArgs; i++) { + args[i] = argList.get(i); + } + return args; + } +} diff --git a/src/scala/scripts/BUILD b/src/scala/scripts/BUILD index 01d989ca7..0e10467b5 100644 --- a/src/scala/scripts/BUILD +++ b/src/scala/scripts/BUILD @@ -1,8 +1,8 @@ load("//scala:scala.bzl", "scala_binary", "scala_library") scala_library( - name = "generator_lib", - srcs = ["TwitterScroogeGenerator.scala"], + name = "scrooge_worker_lib", + srcs = ["ScroogeWorker.scala"], visibility = ["//visibility:public"], deps = [ "//external:io_bazel_rules_scala/dependency/thrift/scrooge_generator", @@ -14,11 +14,11 @@ scala_library( ) scala_binary( - name = "generator", + name = "scrooge_worker", main_class = "scripts.ScroogeWorker", visibility = ["//visibility:public"], deps = [ - ":generator_lib", + ":scrooge_worker_lib", ], ) @@ -29,8 +29,8 @@ scala_library( ) scala_library( - name = "scalapb_generator_lib", - srcs = ["ScalaPBGenerator.scala"], + name = "scalapb_worker_lib", + srcs = ["ScalaPBWorker.scala"], visibility = ["//visibility:public"], runtime_deps = [ "//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java", @@ -47,10 +47,10 @@ scala_library( ) scala_binary( - name = "scalapb_generator", + name = "scalapb_worker", main_class = "scripts.ScalaPBWorker", visibility = ["//visibility:public"], deps = [ - ":scalapb_generator_lib", + ":scalapb_worker_lib", ], ) diff --git a/src/scala/scripts/PBGenerateRequest.scala b/src/scala/scripts/PBGenerateRequest.scala index 5eea6e146..fc1cea7a1 100644 --- a/src/scala/scripts/PBGenerateRequest.scala +++ b/src/scala/scripts/PBGenerateRequest.scala @@ -30,10 +30,10 @@ object PBGenerateRequest { } } - def from(args: java.util.List[String]): PBGenerateRequest = { - val jarOutput = args.get(0) - val protoFiles = args.get(4).split(':') - val includedProto = args.get(1).drop(1).split(':').distinct.map { e => + def from(args: Array[String]): PBGenerateRequest = { + val jarOutput = args(0) + val protoFiles = args(4).split(':') + val includedProto = args(1).drop(1).split(':').distinct.map { e => val p = e.split(',') // If its an empty string then it means we are local to the current repo for the key, no op (Some(p(0)).filter(_.nonEmpty), p(1)) @@ -42,13 +42,13 @@ object PBGenerateRequest { case (Some(k), v) if !protoFiles.contains(v) => (Paths.get(k), Paths.get(v)) }.toList - val flagOpt = args.get(2) match { + val flagOpt = args(2) match { case "-" => None case s if s.charAt(0) == '-' => Some(s.tail) //drop padding character case other => sys.error(s"expected a padding character of - (dash), but found: $other") } - val transitiveProtoPaths: List[String] = (args.get(3) match { + val transitiveProtoPaths: List[String] = (args(3) match { case "-" => Nil case s if s.charAt(0) == '-' => s.tail.split(':').toList //drop padding character case other => sys.error(s"expected a padding character of - (dash), but found: $other") @@ -58,7 +58,7 @@ object PBGenerateRequest { val scalaPBOutput = Files.createTempDirectory(tmp, "bazelscalapb") val flagPrefix = flagOpt.fold("")(_ + ":") - val namedGenerators = args.get(6).drop(1).split(',').filter(_.nonEmpty).map { e => + val namedGenerators = args(6).drop(1).split(',').filter(_.nonEmpty).map { e => val kv = e.split('=') (kv(0), kv(1)) } @@ -69,9 +69,9 @@ object PBGenerateRequest { val scalaPBArgs = outputSettings ::: (padWithProtoPathPrefix(transitiveProtoPaths) ++ protoFiles) - val protoc = Paths.get(args.get(5)) + val protoc = Paths.get(args(5)) - val extraJars = args.get(7).drop(1).split(':').filter(_.nonEmpty).distinct.map {e => Paths.get(e)}.toList + val extraJars = args(7).drop(1).split(':').filter(_.nonEmpty).distinct.map {e => Paths.get(e)}.toList new PBGenerateRequest(jarOutput, scalaPBOutput, scalaPBArgs, includedProto, protoc, namedGenerators, extraJars) } diff --git a/src/scala/scripts/ScalaPBGenerator.scala b/src/scala/scripts/ScalaPBWorker.scala similarity index 80% rename from src/scala/scripts/ScalaPBGenerator.scala rename to src/scala/scripts/ScalaPBWorker.scala index 9ed457d09..26858dda3 100644 --- a/src/scala/scripts/ScalaPBGenerator.scala +++ b/src/scala/scripts/ScalaPBWorker.scala @@ -5,7 +5,7 @@ import java.nio.file.{Path, FileAlreadyExistsException} import io.bazel.rulesscala.io_utils.DeleteRecursively import io.bazel.rulesscala.jar.JarCreator -import io.bazel.rulesscala.worker.{GenericWorker, Processor} +import io.bazel.rulesscala.worker.Worker import protocbridge.{ProtocBridge, ProtocCodeGenerator} import scala.collection.JavaConverters._ import scalapb.ScalaPbCodeGenerator @@ -14,33 +14,17 @@ import scalapb.{ScalaPBC, ScalaPbCodeGenerator, ScalaPbcException} import java.net.URLClassLoader import scala.util.{Try, Failure} -object ScalaPBWorker extends GenericWorker(new ScalaPBGenerator) { +object ScalaPBWorker extends Worker.Interface { - override protected def setupOutput(ps: PrintStream): Unit = { - System.setOut(ps) - System.setErr(ps) - Console.setErr(ps) - Console.setOut(ps) - } - - def main(args: Array[String]) { - try run(args) - catch { - case x: Exception => - x.printStackTrace() - System.exit(1) - } - } -} + def main(args: Array[String]): Unit = Worker.workerMain(args, ScalaPBWorker) -class ScalaPBGenerator extends Processor { def deleteDir(path: Path): Unit = try DeleteRecursively.run(path) catch { case e: Exception => sys.error(s"Problem while deleting path [$path], e.getMessage= ${e.getMessage}") } - def processRequest(args: java.util.List[String]) { + def work(args: Array[String]) { val extractRequestResult = PBGenerateRequest.from(args) val extraClassesClassLoader = new URLClassLoader(extractRequestResult.extraJars.map { e => val f = e.toFile diff --git a/src/scala/scripts/TwitterScroogeGenerator.scala b/src/scala/scripts/ScroogeWorker.scala similarity index 81% rename from src/scala/scripts/TwitterScroogeGenerator.scala rename to src/scala/scripts/ScroogeWorker.scala index 5b1f01a30..cb2839139 100644 --- a/src/scala/scripts/TwitterScroogeGenerator.scala +++ b/src/scala/scripts/ScroogeWorker.scala @@ -7,47 +7,27 @@ import io.bazel.rulesscala.io_utils.DeleteRecursively import java.io.{ File, PrintStream } import java.nio.file.{ Files, Path, Paths } import scala.collection.mutable.Buffer -import io.bazel.rulesscala.worker.{ GenericWorker, Processor } +import io.bazel.rulesscala.worker.Worker import scala.io.Source -/** - * This is our entry point to producing a scala target - * this can act as one of Bazel's persistant workers. - */ -object ScroogeWorker extends GenericWorker(new ScroogeGenerator) { - - override protected def setupOutput(ps: PrintStream): Unit = { - System.setOut(ps) - System.setErr(ps) - Console.setErr(ps) - Console.setOut(ps) - } +object ScroogeWorker extends Worker.Interface { - def main(args: Array[String]) { - try run(args) - catch { - case x: Exception => - x.printStackTrace() - System.exit(1) - } - } -} + def main(args: Array[String]): Unit = Worker.workerMain(args, ScroogeWorker) -class ScroogeGenerator extends Processor { def deleteDir(path: Path): Unit = try DeleteRecursively.run(path) catch { case e: Exception => () } - def processRequest(args: java.util.List[String]) { + def work(args: Array[String]) { def getIdx(i: Int): List[String] = { if (args.size > i) { // bazel worker arguments cannot be empty so we pad to ensure non-empty // and drop it off on the other side // https://github.com/bazelbuild/bazel/issues/3329 val workerArgPadLen = 1 // workerArgPadLen == "_".length - args.get(i) + args(i) .drop(workerArgPadLen) .split(':') .toList @@ -56,7 +36,7 @@ class ScroogeGenerator extends Processor { else Nil } - val jarOutput = args.get(0) + val jarOutput = args(0) // These are the files whose output we want val immediateThriftSrcJars = getIdx(1) // These are all of the files to include when generating scrooge diff --git a/twitter_scrooge/twitter_scrooge.bzl b/twitter_scrooge/twitter_scrooge.bzl index fa2482939..b3361db05 100644 --- a/twitter_scrooge/twitter_scrooge.bzl +++ b/twitter_scrooge/twitter_scrooge.bzl @@ -343,7 +343,7 @@ scrooge_aspect = aspect( "_pluck_scrooge_scala": attr.label( executable = True, cfg = "host", - default = Label("//src/scala/scripts:generator"), + default = Label("//src/scala/scripts:scrooge_worker"), allow_files = True, ), "_scalac": attr.label( From 995975dbc8505de25302d86d8f094ee3fbab374d Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Wed, 19 Feb 2020 15:53:30 -0800 Subject: [PATCH 02/16] whoops! --- src/java/io/bazel/rulesscala/worker/Worker.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 6153e4518..753e42bd6 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -88,7 +88,7 @@ public void checkPermission(Permission permission) { } else { args = workerArgs; } - workerInterface.work(workerArgs); + workerInterface.work(args); } } From 16b4ac3de1fb33e8127a789df6a93527ed27fdb2 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Thu, 20 Feb 2020 11:12:27 -0800 Subject: [PATCH 03/16] git grep --name-only 'String args\[\]' | xargs sed -i 's|String args\[\]|String[] args|' --- .../rulesscala/coverage/instrumenter/JacocoInstrumenter.java | 4 ++-- src/java/io/bazel/rulesscala/scalac/ScalacWorker.java | 2 +- src/java/io/bazel/rulesscala/worker/Worker.java | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java b/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java index cba85e893..f13677ffa 100644 --- a/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java +++ b/src/java/io/bazel/rulesscala/coverage/instrumenter/JacocoInstrumenter.java @@ -27,12 +27,12 @@ public final class JacocoInstrumenter implements Worker.Interface { - public static void main(String args[]) throws Exception { + public static void main(String[] args) throws Exception { Worker.workerMain(args, new JacocoInstrumenter()); } @Override - public void work(String args[]) throws Exception { + public void work(String[] args) throws Exception { Instrumenter jacoco = new Instrumenter(new OfflineInstrumentationAccessGenerator()); for (String arg : args) { processArg(jacoco, arg); diff --git a/src/java/io/bazel/rulesscala/scalac/ScalacWorker.java b/src/java/io/bazel/rulesscala/scalac/ScalacWorker.java index 541f7079a..3cd7b715c 100644 --- a/src/java/io/bazel/rulesscala/scalac/ScalacWorker.java +++ b/src/java/io/bazel/rulesscala/scalac/ScalacWorker.java @@ -42,7 +42,7 @@ class ScalacWorker implements Worker.Interface { } } - public static void main(String args[]) throws Exception { + public static void main(String[] args) throws Exception { Worker.workerMain(args, new ScalacWorker()); } diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 753e42bd6..79af083f6 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -24,7 +24,7 @@ public final class Worker { public static interface Interface { - public void work(String args[]) throws Exception; + public void work(String[] args) throws Exception; } public static void workerMain(String workerArgs[], Interface workerInterface) throws Exception { From 13b282595af7fcb9983c8521b357ddd16df00576 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 21 Feb 2020 10:18:32 -0800 Subject: [PATCH 04/16] subclass ByteArrayOutputStream to allow shrinking of internal buffer --- .../io/bazel/rulesscala/worker/Worker.java | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 79af083f6..5bb907df4 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -42,7 +42,7 @@ public void checkPermission(Permission permission) { InputStream stdin = System.in; PrintStream stdout = System.out; PrintStream stderr = System.err; - ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + ByteArrayOutputStream outStream = new SmartByteArrayOutputStream(); PrintStream out = new PrintStream(outStream); System.setIn(new ByteArrayInputStream(new byte[0])); @@ -92,6 +92,36 @@ public void checkPermission(Permission permission) { } } + /** A ByteArrayOutputStream that sometimes shrinks its internal + * buffer during calls to `reset`. + * + * In contrast, a regular ByteArrayOutputStream will only ever + * grow its internal buffer. + * + * For an example of subclassing a ByteArrayOutputStream, see + * Spring's ResizableByteArrayOutputStream: + * https://github.com/spring-projects/spring-framework/blob/master/spring-core/src/main/java/org/springframework/util/ResizableByteArrayOutputStream.java + */ + private static class SmartByteArrayOutputStream extends ByteArrayOutputStream { + // ByteArrayOutputStream's defualt Size is 32, which is extremely small + // to capture stdout from any worker process. We choose a larger default. + private static final int DEFAULT_SIZE = 256; + + public SmartByteArrayOutputStream() { + super(DEFAULT_SIZE); + } + + @Override + public void reset() { + super.reset(); + // reallocate our internal buffer if we've gone over our + // desired idle size + if (this.buf.length > DEFAULT_SIZE) { + this.buf = new byte[DEFAULT_SIZE]; + } + } + } + private static class ExitTrapped extends RuntimeException { final int code; ExitTrapped(int code) { From 1fe70fc60617ea03365a89058008b4cc0042924b Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 21 Feb 2020 10:18:49 -0800 Subject: [PATCH 05/16] gc after each work request --- src/java/io/bazel/rulesscala/worker/Worker.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 5bb907df4..51636fe96 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -74,6 +74,7 @@ public void checkPermission(Permission permission) { out.flush(); outStream.reset(); + System.gc(); } } catch (IOException e) { } finally { From c289a12de8495cc3563d9f3dfacdba0af6c16cd2 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 21 Feb 2020 10:22:33 -0800 Subject: [PATCH 06/16] split workerMain implementation into two private methods for readability --- .../io/bazel/rulesscala/worker/Worker.java | 121 +++++++++--------- 1 file changed, 64 insertions(+), 57 deletions(-) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 51636fe96..9c89effc2 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -29,68 +29,75 @@ public static interface Interface { public static void workerMain(String workerArgs[], Interface workerInterface) throws Exception { if (workerArgs.length > 0 && workerArgs[0].equals("--persistent_worker")) { + persistentWorkerMain(workerInterface); + } else { + ephemeralWorkerMain(workerArgs, workerInterface); + } + } - System.setSecurityManager(new SecurityManager() { - @Override - public void checkPermission(Permission permission) { - Matcher matcher = exitPattern.matcher(permission.getName()); - if (matcher.find()) - throw new ExitTrapped(Integer.parseInt(matcher.group(1))); - } - }); - - InputStream stdin = System.in; - PrintStream stdout = System.out; - PrintStream stderr = System.err; - ByteArrayOutputStream outStream = new SmartByteArrayOutputStream(); - PrintStream out = new PrintStream(outStream); - - System.setIn(new ByteArrayInputStream(new byte[0])); - System.setOut(out); - System.setErr(out); - - try { - while (true) { - WorkerProtocol.WorkRequest request = - WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin); - - int code = 0; - - try { - workerInterface.work(stringListToArray(request.getArgumentsList())); - } catch (ExitTrapped e) { - code = e.code; - } catch (Exception e) { - System.err.println(e.getMessage()); - e.printStackTrace(); - code = 1; - } - - WorkerProtocol.WorkResponse.newBuilder() - .setOutput(outStream.toString()) - .setExitCode(code) - .build() - .writeDelimitedTo(stdout); - - out.flush(); - outStream.reset(); - System.gc(); + private static void persistentWorkerMain(Interface workerInterface) throws Exception { + System.setSecurityManager(new SecurityManager() { + @Override + public void checkPermission(Permission permission) { + Matcher matcher = exitPattern.matcher(permission.getName()); + if (matcher.find()) + throw new ExitTrapped(Integer.parseInt(matcher.group(1))); } - } catch (IOException e) { - } finally { - System.setIn(stdin); - System.setOut(stdout); - System.setErr(stderr); + }); + + InputStream stdin = System.in; + PrintStream stdout = System.out; + PrintStream stderr = System.err; + ByteArrayOutputStream outStream = new SmartByteArrayOutputStream(); + PrintStream out = new PrintStream(outStream); + + System.setIn(new ByteArrayInputStream(new byte[0])); + System.setOut(out); + System.setErr(out); + + try { + while (true) { + WorkerProtocol.WorkRequest request = + WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin); + + int code = 0; + + try { + workerInterface.work(stringListToArray(request.getArgumentsList())); + } catch (ExitTrapped e) { + code = e.code; + } catch (Exception e) { + System.err.println(e.getMessage()); + e.printStackTrace(); + code = 1; + } + + WorkerProtocol.WorkResponse.newBuilder() + .setOutput(outStream.toString()) + .setExitCode(code) + .build() + .writeDelimitedTo(stdout); + + out.flush(); + outStream.reset(); + System.gc(); } + } catch (IOException e) { + } finally { + System.setIn(stdin); + System.setOut(stdout); + System.setErr(stderr); + } + } + + private static void ephemeralWorkerMain(String workerArgs[], Interface workerInterface) throws Exception { + String[] args; + if (workerArgs.length == 1 && workerArgs[0].startsWith("@")) { + args = stringListToArray(Files.readAllLines(Paths.get(workerArgs[0].substring(1)), StandardCharsets.UTF_8)); } else { - String[] args; - if (workerArgs.length == 1 && workerArgs[0].startsWith("@")) { - args = stringListToArray(Files.readAllLines(Paths.get(workerArgs[0].substring(1)), StandardCharsets.UTF_8)); - } else { - args = workerArgs; - } - workerInterface.work(args); + args = workerArgs; } + workerInterface.work(args); } /** A ByteArrayOutputStream that sometimes shrinks its internal From 9e8f6738427aaa622c0bb0fb8bac8c87ca617aca Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 21 Feb 2020 10:24:07 -0800 Subject: [PATCH 07/16] add comment about stdin --- src/java/io/bazel/rulesscala/worker/Worker.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 9c89effc2..380478a83 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -51,7 +51,9 @@ public void checkPermission(Permission permission) { ByteArrayOutputStream outStream = new SmartByteArrayOutputStream(); PrintStream out = new PrintStream(outStream); + // We can't support stdin, so assign it to read from an empty buffer System.setIn(new ByteArrayInputStream(new byte[0])); + System.setOut(out); System.setErr(out); From aba90b524b2e48daa8eca568d37085c7aa14265d Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Fri, 21 Feb 2020 11:37:01 -0800 Subject: [PATCH 08/16] add a test for the buffer --- src/java/io/bazel/rulesscala/worker/BUILD | 11 ++++++ .../io/bazel/rulesscala/worker/Worker.java | 8 +++-- .../bazel/rulesscala/worker/WorkerTest.java | 36 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 src/java/io/bazel/rulesscala/worker/WorkerTest.java diff --git a/src/java/io/bazel/rulesscala/worker/BUILD b/src/java/io/bazel/rulesscala/worker/BUILD index 6a08c9643..c0e4eb327 100644 --- a/src/java/io/bazel/rulesscala/worker/BUILD +++ b/src/java/io/bazel/rulesscala/worker/BUILD @@ -6,3 +6,14 @@ java_library( "//third_party/bazel/src/main/protobuf:worker_protocol_java_proto", ], ) + +java_test( + name = "worker_test", + srcs = [ + "WorkerTest.java", + ], + test_class = "io.bazel.rulesscala.worker.WorkerTest", + deps = [ + ":worker", + ], +) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 380478a83..65c5f35dc 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -112,7 +112,7 @@ private static void ephemeralWorkerMain(String workerArgs[], Interface workerInt * Spring's ResizableByteArrayOutputStream: * https://github.com/spring-projects/spring-framework/blob/master/spring-core/src/main/java/org/springframework/util/ResizableByteArrayOutputStream.java */ - private static class SmartByteArrayOutputStream extends ByteArrayOutputStream { + static class SmartByteArrayOutputStream extends ByteArrayOutputStream { // ByteArrayOutputStream's defualt Size is 32, which is extremely small // to capture stdout from any worker process. We choose a larger default. private static final int DEFAULT_SIZE = 256; @@ -121,12 +121,16 @@ public SmartByteArrayOutputStream() { super(DEFAULT_SIZE); } + public boolean isOversized() { + return this.buf.length > DEFAULT_SIZE; + } + @Override public void reset() { super.reset(); // reallocate our internal buffer if we've gone over our // desired idle size - if (this.buf.length > DEFAULT_SIZE) { + if (this.isOversized()) { this.buf = new byte[DEFAULT_SIZE]; } } diff --git a/src/java/io/bazel/rulesscala/worker/WorkerTest.java b/src/java/io/bazel/rulesscala/worker/WorkerTest.java new file mode 100644 index 000000000..92d770a91 --- /dev/null +++ b/src/java/io/bazel/rulesscala/worker/WorkerTest.java @@ -0,0 +1,36 @@ +package io.bazel.rulesscala.worker; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +@RunWith(JUnit4.class) +public class WorkerTest { + + private static void fill(ByteArrayOutputStream baos, int amount) { + for (int i = 0; i < amount; i++) { + baos.write(0); + } + } + + @Test + public void testWriteReadAndReset() throws Exception { + Worker.SmartByteArrayOutputStream baos = new Worker.SmartByteArrayOutputStream(); + PrintStream out = new PrintStream(baos); + + out.print("hello, world"); + assert(baos.toString("UTF-8").equals("hello, world")); + assert(!baos.isOversized()); + + fill(baos, 300); + assert(baos.isOversized()); + baos.reset(); + + out.print("goodbye, world"); + assert(baos.toString("UTF-8").equals("goodbye, world")); + assert(!baos.isOversized()); + } +} From f92df24c47f6aef090172bfa3502020b8caf3ffb Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Sat, 22 Feb 2020 16:55:48 -0800 Subject: [PATCH 09/16] improve tests --- src/java/io/bazel/rulesscala/worker/BUILD | 4 + .../io/bazel/rulesscala/worker/Worker.java | 6 +- .../bazel/rulesscala/worker/WorkerTest.java | 126 +++++++++++++++++- 3 files changed, 134 insertions(+), 2 deletions(-) diff --git a/src/java/io/bazel/rulesscala/worker/BUILD b/src/java/io/bazel/rulesscala/worker/BUILD index c0e4eb327..d55126c30 100644 --- a/src/java/io/bazel/rulesscala/worker/BUILD +++ b/src/java/io/bazel/rulesscala/worker/BUILD @@ -12,8 +12,12 @@ java_test( srcs = [ "WorkerTest.java", ], + jvm_flags = [ + "-Dcom.google.testing.junit.runner.shouldInstallTestSecurityManager=false", + ], test_class = "io.bazel.rulesscala.worker.WorkerTest", deps = [ ":worker", + "//third_party/bazel/src/main/protobuf:worker_protocol_java_proto", ], ) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 65c5f35dc..6f2ba8f5c 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -62,6 +62,10 @@ public void checkPermission(Permission permission) { WorkerProtocol.WorkRequest request = WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin); + if (request == null) { + break; + } + int code = 0; try { @@ -136,7 +140,7 @@ public void reset() { } } - private static class ExitTrapped extends RuntimeException { + static class ExitTrapped extends RuntimeException { final int code; ExitTrapped(int code) { super(); diff --git a/src/java/io/bazel/rulesscala/worker/WorkerTest.java b/src/java/io/bazel/rulesscala/worker/WorkerTest.java index 92d770a91..d877d72e8 100644 --- a/src/java/io/bazel/rulesscala/worker/WorkerTest.java +++ b/src/java/io/bazel/rulesscala/worker/WorkerTest.java @@ -1,15 +1,97 @@ package io.bazel.rulesscala.worker; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; import java.io.PrintStream; +import java.lang.SecurityManager; + +import com.google.devtools.build.lib.worker.WorkerProtocol; @RunWith(JUnit4.class) public class WorkerTest { + @Test + public void testEphemeralWorkerSystemExit() throws Exception { + + // An ephemeral worker behaves like a regular main method, + // so we expect the worker to system exit normally + + Worker.Interface worker = new Worker.Interface() { + @Override + public void work(String[] args) { + System.exit(99); + } + }; + + // we expect ephemeral workers to just exit normally + int code = assertThrows(Worker.ExitTrapped.class, () -> + Worker.workerMain(new String[]{}, worker)).code; + + assert(code == 99); + } + + @Test + public void testPersistentWorkerSystemExit() throws Exception { + + // We're going to spin up a persistent worker and run a single + // work request. We expect System exists to impact the worker + // request lifecycle without exiting the overall worker + // process. + + Worker.Interface worker = new Worker.Interface() { + @Override + public void work(String[] args) { + // we should see this print statement + System.out.println("before exit"); + System.exit(99); + // we should not see this print statement + System.out.println("after exit"); + } + }; + + try ( + PipedInputStream in = new PipedInputStream(); + PipedOutputStream outToIn = new PipedOutputStream(in); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + ) { + + InputStream stdin = System.in; + PrintStream stdout = System.out; + PrintStream stderr = System.err; + + System.setIn(in); + System.setOut(new PrintStream(out)); + + WorkerProtocol.WorkRequest.newBuilder() + .build() + .writeDelimitedTo(outToIn); + + // otherwise the worker will poll indefinitely + outToIn.close(); + + Worker.workerMain(new String[]{"--persistent_worker"}, worker); + + System.setIn(stdin); + System.setOut(stdout); + System.setErr(stderr); + + String outString = out.toString("UTF-8"); + // check to make sure the before statement printed + assert(outString.contains("before")); + // check to make sure the after statement did not print + assert(!outString.contains("after")); + } + } + private static void fill(ByteArrayOutputStream baos, int amount) { for (int i = 0; i < amount; i++) { baos.write(0); @@ -17,7 +99,7 @@ private static void fill(ByteArrayOutputStream baos, int amount) { } @Test - public void testWriteReadAndReset() throws Exception { + public void testBufferWriteReadAndReset() throws Exception { Worker.SmartByteArrayOutputStream baos = new Worker.SmartByteArrayOutputStream(); PrintStream out = new PrintStream(baos); @@ -33,4 +115,46 @@ public void testWriteReadAndReset() throws Exception { assert(baos.toString("UTF-8").equals("goodbye, world")); assert(!baos.isOversized()); } + + @AfterClass + public static void teardown() { + // Persistent workers install a security manager. We need to + // reset it here so that our own process can exit! + System.setSecurityManager(null); + } + + // Copied/modified from Bazel's MoreAsserts + // + // Note: this goes away soon-ish, as JUnit 4.13 was recently + // released and includes assertThrows + public static T assertThrows( + Class expectedThrowable, + ThrowingRunnable runnable) + { + try { + runnable.run(); + } catch (Throwable actualThrown) { + if (expectedThrowable.isInstance(actualThrown)) { + @SuppressWarnings("unchecked") + T retVal = (T) actualThrown; + return retVal; + } else { + throw new AssertionError( + String.format( + "expected %s to be thrown, but %s was thrown", + expectedThrowable.getSimpleName(), + actualThrown.getClass().getSimpleName()), + actualThrown); + } + } + String mismatchMessage = String.format( + "expected %s to be thrown, but nothing was thrown", + expectedThrowable.getSimpleName()); + throw new AssertionError(mismatchMessage); + } + + // see note on assertThrows + public interface ThrowingRunnable { + void run() throws Throwable; + } } From 87104f7464fff437641231d943ccb1d3971de547 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Sat, 22 Feb 2020 17:09:37 -0800 Subject: [PATCH 10/16] wrap gc and friends in the finally clause --- .../io/bazel/rulesscala/worker/Worker.java | 65 +++++++++++-------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 6f2ba8f5c..b6c89d9dc 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -35,7 +35,7 @@ public static void workerMain(String workerArgs[], Interface workerInterface) th } } - private static void persistentWorkerMain(Interface workerInterface) throws Exception { + private static void persistentWorkerMain(Interface workerInterface) { System.setSecurityManager(new SecurityManager() { @Override public void checkPermission(Permission permission) { @@ -59,36 +59,45 @@ public void checkPermission(Permission permission) { try { while (true) { - WorkerProtocol.WorkRequest request = - WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin); - - if (request == null) { - break; - } - - int code = 0; - try { - workerInterface.work(stringListToArray(request.getArgumentsList())); - } catch (ExitTrapped e) { - code = e.code; - } catch (Exception e) { - System.err.println(e.getMessage()); - e.printStackTrace(); - code = 1; + WorkerProtocol.WorkRequest request = + WorkerProtocol.WorkRequest.parseDelimitedFrom(stdin); + + // The request will be null if stdin is closed. We're + // not sure if this happens in TheRealWorld™ but it is + // useful for testing (to shut down a persistent + // worker process). + if (request == null) { + break; + } + + int code = 0; + + try { + workerInterface.work(stringListToArray(request.getArgumentsList())); + } catch (ExitTrapped e) { + code = e.code; + } catch (Exception e) { + System.err.println(e.getMessage()); + e.printStackTrace(); + code = 1; + } + + WorkerProtocol.WorkResponse.newBuilder() + .setOutput(outStream.toString()) + .setExitCode(code) + .build() + .writeDelimitedTo(stdout); + + } catch (IOException e) { + // for now we swallow IOExceptions when + // reading/writing proto + } finally { + out.flush(); + outStream.reset(); + System.gc(); } - - WorkerProtocol.WorkResponse.newBuilder() - .setOutput(outStream.toString()) - .setExitCode(code) - .build() - .writeDelimitedTo(stdout); - - out.flush(); - outStream.reset(); - System.gc(); } - } catch (IOException e) { } finally { System.setIn(stdin); System.setOut(stdout); From 118581ffd55e664d1c9348967d6d3277d9da3c65 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Sat, 22 Feb 2020 17:17:53 -0800 Subject: [PATCH 11/16] comments --- src/java/io/bazel/rulesscala/worker/Worker.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index b6c89d9dc..26f60fa4b 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -27,6 +27,11 @@ public static interface Interface { public void work(String[] args) throws Exception; } + /** The entry point for all workers. + * + * This should be the only thing called by a main method in a + * worker process. + */ public static void workerMain(String workerArgs[], Interface workerInterface) throws Exception { if (workerArgs.length > 0 && workerArgs[0].equals("--persistent_worker")) { persistentWorkerMain(workerInterface); @@ -35,6 +40,7 @@ public static void workerMain(String workerArgs[], Interface workerInterface) th } } + /** The main loop for persistent worker processes */ private static void persistentWorkerMain(Interface workerInterface) { System.setSecurityManager(new SecurityManager() { @Override @@ -105,6 +111,8 @@ public void checkPermission(Permission permission) { } } + /** The single pass runner for ephemeral (non-persistent) worker + * processes */ private static void ephemeralWorkerMain(String workerArgs[], Interface workerInterface) throws Exception { String[] args; if (workerArgs.length == 1 && workerArgs[0].startsWith("@")) { From 8ba11b51a94c31534730c553153c205fa2f16d8d Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Sat, 22 Feb 2020 17:19:10 -0800 Subject: [PATCH 12/16] make the worker public --- src/java/io/bazel/rulesscala/worker/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/java/io/bazel/rulesscala/worker/BUILD b/src/java/io/bazel/rulesscala/worker/BUILD index d55126c30..60ba51faf 100644 --- a/src/java/io/bazel/rulesscala/worker/BUILD +++ b/src/java/io/bazel/rulesscala/worker/BUILD @@ -1,7 +1,7 @@ java_library( name = "worker", srcs = ["Worker.java"], - visibility = ["//:__subpackages__"], + visibility = ["//visibility:public"], deps = [ "//third_party/bazel/src/main/protobuf:worker_protocol_java_proto", ], From 2bc7a938c3a10fb97777a71c754c1c1590044cac Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Sat, 22 Feb 2020 17:52:55 -0800 Subject: [PATCH 13/16] cleanup --- .../io/bazel/rulesscala/worker/Worker.java | 4 +-- .../bazel/rulesscala/worker/WorkerTest.java | 33 ++++++++++--------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 26f60fa4b..517550933 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -6,9 +6,10 @@ import java.io.InputStream; import java.io.OutputStream; import java.io.PrintStream; +import java.lang.SecurityManager; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; -import java.lang.SecurityManager; import java.security.Permission; import java.util.ArrayList; import java.util.LinkedList; @@ -17,7 +18,6 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; -import java.nio.charset.StandardCharsets; import com.google.devtools.build.lib.worker.WorkerProtocol; diff --git a/src/java/io/bazel/rulesscala/worker/WorkerTest.java b/src/java/io/bazel/rulesscala/worker/WorkerTest.java index d877d72e8..b903e6ce4 100644 --- a/src/java/io/bazel/rulesscala/worker/WorkerTest.java +++ b/src/java/io/bazel/rulesscala/worker/WorkerTest.java @@ -31,7 +31,6 @@ public void work(String[] args) { } }; - // we expect ephemeral workers to just exit normally int code = assertThrows(Worker.ExitTrapped.class, () -> Worker.workerMain(new String[]{}, worker)).code; @@ -42,8 +41,8 @@ public void work(String[] args) { public void testPersistentWorkerSystemExit() throws Exception { // We're going to spin up a persistent worker and run a single - // work request. We expect System exists to impact the worker - // request lifecycle without exiting the overall worker + // work request. We expect System.exit calls to impact the + // worker request lifecycle without exiting the overall worker // process. Worker.Interface worker = new Worker.Interface() { @@ -51,32 +50,33 @@ public void testPersistentWorkerSystemExit() throws Exception { public void work(String[] args) { // we should see this print statement System.out.println("before exit"); - System.exit(99); + System.exit(100); // we should not see this print statement System.out.println("after exit"); } }; try ( - PipedInputStream in = new PipedInputStream(); - PipedOutputStream outToIn = new PipedOutputStream(in); + PipedInputStream workerIn = new PipedInputStream(); + PipedOutputStream outToWorkerIn = new PipedOutputStream(workerIn); - ByteArrayOutputStream out = new ByteArrayOutputStream(); + PipedOutputStream workerOut = new PipedOutputStream(); + PipedInputStream inFromWorkerOut = new PipedInputStream(workerOut); ) { InputStream stdin = System.in; PrintStream stdout = System.out; PrintStream stderr = System.err; - System.setIn(in); - System.setOut(new PrintStream(out)); + System.setIn(workerIn); + System.setOut(new PrintStream(workerOut)); WorkerProtocol.WorkRequest.newBuilder() .build() - .writeDelimitedTo(outToIn); + .writeDelimitedTo(outToWorkerIn); // otherwise the worker will poll indefinitely - outToIn.close(); + outToWorkerIn.close(); Worker.workerMain(new String[]{"--persistent_worker"}, worker); @@ -84,11 +84,12 @@ public void work(String[] args) { System.setOut(stdout); System.setErr(stderr); - String outString = out.toString("UTF-8"); - // check to make sure the before statement printed - assert(outString.contains("before")); - // check to make sure the after statement did not print - assert(!outString.contains("after")); + WorkerProtocol.WorkResponse response = + WorkerProtocol.WorkResponse.parseDelimitedFrom(inFromWorkerOut); + + assert(response.getOutput().contains("before")); + assert(response.getExitCode() == 100); + assert(!response.getOutput().contains("after")); } } From ac308e22f50d78efafd5cc6d9611ae5be96cdcce Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Sat, 22 Feb 2020 18:00:57 -0800 Subject: [PATCH 14/16] comments --- src/java/io/bazel/rulesscala/worker/Worker.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/java/io/bazel/rulesscala/worker/Worker.java b/src/java/io/bazel/rulesscala/worker/Worker.java index 517550933..9ba6ddcec 100644 --- a/src/java/io/bazel/rulesscala/worker/Worker.java +++ b/src/java/io/bazel/rulesscala/worker/Worker.java @@ -21,6 +21,15 @@ import com.google.devtools.build.lib.worker.WorkerProtocol; +/** A base for JVM workers. + * + * This supports regular workers as well as persisent workers. It + * does not (yet) support multiplexed workers. + * + * Worker implementations should implement the `Worker.Interface` + * interface and provide a main method that calls `Worker.workerMain`. + * + */ public final class Worker { public static interface Interface { @@ -90,8 +99,8 @@ public void checkPermission(Permission permission) { } WorkerProtocol.WorkResponse.newBuilder() - .setOutput(outStream.toString()) .setExitCode(code) + .setOutput(outStream.toString()) .build() .writeDelimitedTo(stdout); From 2b020122f21948493792be6a8fa94c1280dee025 Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Tue, 25 Feb 2020 08:52:08 -0800 Subject: [PATCH 15/16] add stdin test --- .../bazel/rulesscala/worker/WorkerTest.java | 127 +++++++++++++----- 1 file changed, 93 insertions(+), 34 deletions(-) diff --git a/src/java/io/bazel/rulesscala/worker/WorkerTest.java b/src/java/io/bazel/rulesscala/worker/WorkerTest.java index b903e6ce4..e0cf7d4c3 100644 --- a/src/java/io/bazel/rulesscala/worker/WorkerTest.java +++ b/src/java/io/bazel/rulesscala/worker/WorkerTest.java @@ -8,10 +8,13 @@ import java.io.ByteArrayOutputStream; import java.io.InputStream; +import java.io.IOException; import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.io.PrintStream; +import java.lang.AutoCloseable; import java.lang.SecurityManager; +import java.util.Scanner; import com.google.devtools.build.lib.worker.WorkerProtocol; @@ -39,53 +42,33 @@ public void work(String[] args) { @Test public void testPersistentWorkerSystemExit() throws Exception { - // We're going to spin up a persistent worker and run a single // work request. We expect System.exit calls to impact the // worker request lifecycle without exiting the overall worker // process. - Worker.Interface worker = new Worker.Interface() { - @Override - public void work(String[] args) { - // we should see this print statement - System.out.println("before exit"); - System.exit(100); - // we should not see this print statement - System.out.println("after exit"); - } - }; - try ( - PipedInputStream workerIn = new PipedInputStream(); - PipedOutputStream outToWorkerIn = new PipedOutputStream(workerIn); - - PipedOutputStream workerOut = new PipedOutputStream(); - PipedInputStream inFromWorkerOut = new PipedInputStream(workerOut); + PersistentWorkerHelper helper = new PersistentWorkerHelper(); ) { - - InputStream stdin = System.in; - PrintStream stdout = System.out; - PrintStream stderr = System.err; - - System.setIn(workerIn); - System.setOut(new PrintStream(workerOut)); - WorkerProtocol.WorkRequest.newBuilder() .build() - .writeDelimitedTo(outToWorkerIn); + .writeDelimitedTo(helper.requestOut); - // otherwise the worker will poll indefinitely - outToWorkerIn.close(); - - Worker.workerMain(new String[]{"--persistent_worker"}, worker); + Worker.Interface worker = new Worker.Interface() { + @Override + public void work(String[] args) { + // we should see this print statement + System.out.println("before exit"); + System.exit(100); + // we should not see this print statement + System.out.println("after exit"); + } + }; - System.setIn(stdin); - System.setOut(stdout); - System.setErr(stderr); + helper.runWorker(worker); WorkerProtocol.WorkResponse response = - WorkerProtocol.WorkResponse.parseDelimitedFrom(inFromWorkerOut); + WorkerProtocol.WorkResponse.parseDelimitedFrom(helper.responseIn); assert(response.getOutput().contains("before")); assert(response.getExitCode() == 100); @@ -93,6 +76,82 @@ public void work(String[] args) { } } + @Test + public void testPersistentWorkerNoStdin() throws Exception { + try ( + PersistentWorkerHelper helper = new PersistentWorkerHelper(); + ) { + WorkerProtocol.WorkRequest.newBuilder() + .build() + .writeDelimitedTo(helper.requestOut); + + Worker.Interface worker = new Worker.Interface() { + @Override + public void work(String[] args) throws Exception { + assert(System.in.read() == -1); + } + }; + + helper.runWorker(worker); + + WorkerProtocol.WorkResponse response = + WorkerProtocol.WorkResponse.parseDelimitedFrom(helper.responseIn); + } + } + + /** A helper to manage IO when testing a persistent worker. */ + private final class PersistentWorkerHelper implements AutoCloseable { + + public final PipedInputStream workerIn; + public final PipedOutputStream requestOut; + + public final PipedOutputStream workerOut; + public final PipedInputStream responseIn; + + private final InputStream stdin; + private final PrintStream stdout; + private final PrintStream stderr; + + public PersistentWorkerHelper() throws IOException { + this.workerIn = new PipedInputStream(); + this.requestOut = new PipedOutputStream(workerIn); + this.workerOut = new PipedOutputStream(); + this.responseIn = new PipedInputStream(workerOut); + + this.stdin = System.in; + this.stdout = System.out; + this.stderr = System.err; + + System.setIn(this.workerIn); + System.setOut(new PrintStream(this.workerOut)); + } + + public void runWorker(Worker.Interface worker) throws Exception{ + // otherwise the worker will poll indefinitely + this.requestOut.close(); + Worker.workerMain(new String[]{"--persistent_worker"}, worker); + } + + public void close() { + try { + this.workerIn.close(); + } catch (IOException e) {} + try { + this.requestOut.close(); + } catch (IOException e) {} + try { + this.workerOut.close(); + } catch (IOException e) {} + try { + this.responseIn.close(); + } catch (IOException e) {} + + System.setIn(this.stdin); + System.setOut(this.stdout); + System.setErr(this.stderr); + } + } + private static void fill(ByteArrayOutputStream baos, int amount) { for (int i = 0; i < amount; i++) { baos.write(0); From fe63cf21547deb1f35df21a1950ea4306a8e8f3f Mon Sep 17 00:00:00 2001 From: Andy Scott Date: Tue, 28 Apr 2020 22:19:47 -0700 Subject: [PATCH 16/16] clean up test --- src/java/io/bazel/rulesscala/worker/WorkerTest.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/java/io/bazel/rulesscala/worker/WorkerTest.java b/src/java/io/bazel/rulesscala/worker/WorkerTest.java index e0cf7d4c3..53c3aa7d3 100644 --- a/src/java/io/bazel/rulesscala/worker/WorkerTest.java +++ b/src/java/io/bazel/rulesscala/worker/WorkerTest.java @@ -6,6 +6,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.util.concurrent.atomic.AtomicInteger; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.IOException; @@ -16,6 +17,7 @@ import java.lang.SecurityManager; import java.util.Scanner; + import com.google.devtools.build.lib.worker.WorkerProtocol; @RunWith(JUnit4.class) @@ -85,17 +87,16 @@ public void testPersistentWorkerNoStdin() throws Exception { .build() .writeDelimitedTo(helper.requestOut); + final AtomicInteger result = new AtomicInteger(); Worker.Interface worker = new Worker.Interface() { @Override public void work(String[] args) throws Exception { - assert(System.in.read() == -1); + result.set(System.in.read()); } }; helper.runWorker(worker); - - WorkerProtocol.WorkResponse response = - WorkerProtocol.WorkResponse.parseDelimitedFrom(helper.responseIn); + assert(result.get() == -1); } }