Skip to content

Commit

Permalink
Refactor concurrency handling code
Browse files Browse the repository at this point in the history
  • Loading branch information
lukebemish committed Dec 11, 2024
1 parent d8517ed commit 46805de
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 89 deletions.
183 changes: 121 additions & 62 deletions src/main/java/dev/lukebemish/immaculate/ForkFormatter.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -24,7 +23,6 @@

public class ForkFormatter implements FileFormatter {
private final Process process;
private final Socket socket;
private final ResultListener listener;

public ForkFormatter(ForkFormatterSpec spec) {
Expand All @@ -35,6 +33,9 @@ public ForkFormatter(ForkFormatterSpec spec) {
List<String> args = new ArrayList<>();
args.add(spec.getJavaLauncher().get().getExecutablePath().getAsFile().toString());
args.addAll(spec.getJvmArgs().get());
if (spec.getHideStacktrace().get()) {
args.add("-Ddev.lukebemish.immaculate.wrapper.hidestacktrace=true");
}
args.addAll(List.of(
"-cp",
spec.getClasspath().getAsPath(),
Expand Down Expand Up @@ -66,9 +67,7 @@ public ForkFormatter(ForkFormatterSpec spec) {
try {
String socketPortString = socketPort.get(4000, TimeUnit.MILLISECONDS);
int port = Integer.parseInt(socketPortString);
this.socket = new Socket(InetAddress.getLoopbackAddress(), port);

this.listener = new ResultListener(socket);
this.listener = new ResultListener(new Socket(InetAddress.getLoopbackAddress(), port));
this.listener.start();
} catch (InterruptedException | ExecutionException | TimeoutException | IOException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -103,83 +102,156 @@ public void run() {
}
}

private static final class SocketHandle {
private final DataOutputStream output;
private final DataInputStream input;
private final Socket socket;

private SocketHandle(Socket socket) throws IOException {
this.output = new DataOutputStream(socket.getOutputStream());
this.input = new DataInputStream(socket.getInputStream());
this.socket = socket;
}

synchronized void writeSubmission(int id, String fileName, String text) throws IOException {
output.writeInt(id);
output.writeUTF(fileName);
output.writeUTF(text);
output.flush();
}

// Will be true only if a shutdown signal is successfully sent to the channel.
private volatile boolean gracefulShutdown = false;

synchronized void shutdown() throws IOException {
try {
// -1 ID signals the end of submissions
output.writeInt(-1);
output.flush();
this.gracefulShutdown = true;
} finally {
// Then close the socket
socket.close();
}
}

int readId() throws IOException {
try {
return input.readInt();
} catch (SocketException e) {
// Could be the socket is intentionally closed during cleanup, could be something went sideways.
// To differentiate -- check gracefulShutdown
if (gracefulShutdown) {
return -1;
}
throw e;
}
}

boolean readSuccess() throws IOException {
return input.readBoolean();
}

String readResult() throws IOException {
return input.readUTF();
}
}

private static final class ResultListener extends Thread {
private final Map<Integer, CompletableFuture<String>> results = new ConcurrentHashMap<>();
private final Socket socket;
private final DataOutputStream output;
private final SocketHandle socketHandle;
// Handle uncaught exceptions by re-throwing them on shutdown
private volatile Throwable thrownException;

private ResultListener(Socket socket) throws IOException {
this.socket = socket;
output = new DataOutputStream(socket.getOutputStream());
this.socketHandle = new SocketHandle(socket);
this.setUncaughtExceptionHandler((t, e) -> {
try {
shutdown(e);
thrownException = e;
} catch (IOException ex) {
var exception = new UncheckedIOException(ex);
exception.addSuppressed(e);
ResultListener.this.getThreadGroup().uncaughtException(t, exception);
thrownException = exception;
}
ResultListener.this.getThreadGroup().uncaughtException(t, e);
});
}

public synchronized CompletableFuture<String> submit(int id, String fileName, String text) throws IOException {
// Non-blocking, returns a future that will complete when the result is available (or throws if the listener is closed early unexpectedly)
public CompletableFuture<String> submit(int id, String fileName, String text) throws IOException {
if (closed.get()) {
throw new IOException("Listener is closed");
return CompletableFuture.failedFuture(new IOException("Listener is closed"));
}
var out = results.computeIfAbsent(id, i -> new CompletableFuture<>());
output.writeInt(id);
byte[] fileNameBytes = fileName.getBytes(StandardCharsets.UTF_8);
byte[] textBytes = text.getBytes(StandardCharsets.UTF_8);
output.writeInt(fileNameBytes.length);
output.write(fileNameBytes);
output.writeInt(textBytes.length);
output.write(textBytes);
output.flush();
// Submissions to the child process take the format ID, file name, file contents -- the ID lets the result be matched up
socketHandle.writeSubmission(id, fileName, text);
return out;
}

private final AtomicBoolean closed = new AtomicBoolean();

public void shutdown() throws IOException {
// Blocks until proper thread shutdown
public void ensureShutdown() throws Throwable {
/*
Cleans up the child process, stops the listener thread, and joins the thread ensuring it is closed, rethrowing
exceptions as necessary.
*/
shutdown(new IOException("Execution was interrupted"));

try {
this.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

if (thrownException != null) {
throw thrownException;
}
}

// Non-blocking
private void shutdown(Throwable t) throws IOException {
/*
This method handles graceful shutdown of the child process and forcing the listener thread to stop. It does
not ensure the listener thread is closed.
- ensure the shutdown logic runs exactly once, in the proper order; it is possible for logic running on the
thread to request a shutdown during the shutdown process initialized from another thread.
- prevent submission of new tasks
- complete all pending tasks exceptionally (with the provided exception if one is given)
- stop the child process (by sending it a "shutdown" signal with ID -1)
- stop the thread if it is running. The thread could be waiting at a number of places. Either:
- the readId() call, if everything is running normally
- the readResult() or readSuccess() call, if something is going badly wrong in the child process
- not waiting, just in the loop -- the "closed" flag will be checked at the top of the loop
to stop in either of these cases, we simply close the socket; this results in anything blocking on reading
from the socket throwing an exception (see Socket#close()).
*/

// Prevent multiple concurrent shutdowns
if (!this.closed.compareAndSet(false, true)) return;

for (var future : results.values()) {
future.completeExceptionally(t);
}
results.clear();

socket.shutdownInput();

if (Thread.currentThread() != this) {
try {
this.join();
} catch (InterruptedException e) {
// continue, it's fine
}
}

output.writeInt(-1);
output.flush();
socket.close();
socketHandle.shutdown();
}

@Override
public void run() {
try {
if (!closed.get()) {
var input = new DataInputStream(socket.getInputStream());
while (!closed.get()) {
int id = input.readInt();
boolean success = input.readBoolean();
int id = socketHandle.readId();
if (id == -1) {
// The child process has been sent a shutdown signal gracefully
shutdown(new IOException("Listener is closed"));
break;
}
boolean success = socketHandle.readSuccess();
if (success) {
int length = input.readInt();
byte[] bytes = input.readNBytes(length);
String result = new String(bytes, StandardCharsets.UTF_8);
String result = socketHandle.readResult();
var future = results.remove(id);
if (future != null) {
future.complete(result);
Expand All @@ -193,41 +265,28 @@ public void run() {
}
}
}
} catch (EOFException e) {
try {
shutdown(e);
} catch (IOException ex) {
throw new UncheckedIOException(ex);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
}

@Override
public synchronized void close() {
List<Exception> suppressed = new ArrayList<>();
public void close() {
List<Throwable> suppressed = new ArrayList<>();
if (listener != null) {
try {
listener.shutdown();
} catch (Exception e) {
suppressed.add(e);
}
}
if (socket != null) {
try {
socket.close();
} catch (Exception e) {
suppressed.add(e);
listener.ensureShutdown();
} catch (Throwable t) {
suppressed.add(t);
}
}
if (process != null) {
try {
process.destroy();
process.waitFor();
} catch (Exception e) {
suppressed.add(e);
} catch (Throwable t) {
suppressed.add(t);
}
}
if (!suppressed.isEmpty()) {
Expand Down
65 changes: 38 additions & 27 deletions wrapper/src/main/java/dev/lukebemish/immaculate/wrapper/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.net.ServerSocket;
import java.nio.charset.StandardCharsets;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -57,29 +57,30 @@ private void run() throws IOException {
// This tells the parent process what port we're listening on
System.out.println(socket.getLocalPort());
var socket = this.socket.accept();
var input = new DataInputStream(socket.getInputStream());
var os = new DataOutputStream(socket.getOutputStream());
var output = new Output(os);
// Communication back to the parent is done through this handle, which ensures synchronization on the output stream.
var socketHandle = new SocketHandle(socket);
while (true) {
int id = input.readInt();
int id = socketHandle.readId();
if (id == -1) {
// We have been sent a signal to gracefully shutdown, so we stop processing new submissions
break;
}
String fileName = new String(input.readNBytes(input.readInt()), StandardCharsets.UTF_8);
String text = new String(input.readNBytes(input.readInt()), StandardCharsets.UTF_8);
execute(id, fileName, text, output);
String fileName = socketHandle.readUTF();
String text = socketHandle.readUTF();
// Submissions to the child process take the format ID, file name, file contents
execute(id, fileName, text, socketHandle);
}
}

private void execute(int id, String fileName, String text, Output output) {
private void execute(int id, String fileName, String text, SocketHandle socketHandle) {
executor.submit(() -> {
try {
String result = wrapper.format(fileName, text);
output.writeSuccess(id, result);
socketHandle.writeSuccess(id, result);
} catch (Throwable t) {
logException(t);
try {
output.writeFailure(id);
socketHandle.writeFailure(id);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -96,24 +97,34 @@ private static void logException(Throwable t) {
}
}

private record Output(DataOutputStream output) {
void writeFailure(int id) throws IOException {
synchronized (this) {
output.writeInt(id);
output.writeBoolean(false);
output.flush();
}
private static final class SocketHandle {
private final DataOutputStream output;
private final DataInputStream input;

private SocketHandle(Socket socket) throws IOException {
this.output = new DataOutputStream(socket.getOutputStream());
this.input = new DataInputStream(socket.getInputStream());
}

void writeSuccess(int id, String result) throws IOException {
synchronized (this) {
output.writeInt(id);
output.writeBoolean(true);
byte[] bytes = result.getBytes(StandardCharsets.UTF_8);
output.writeInt(bytes.length);
output.write(bytes);
output.flush();
}
synchronized void writeFailure(int id) throws IOException {
output.writeInt(id);
output.writeBoolean(false);
output.flush();
}

synchronized void writeSuccess(int id, String result) throws IOException {
output.writeInt(id);
output.writeBoolean(true);
output.writeUTF(result);
output.flush();
}

int readId() throws IOException {
return input.readInt();
}

String readUTF() throws IOException {
return input.readUTF();
}
}
}

0 comments on commit 46805de

Please sign in to comment.