Skip to content

Commit

Permalink
Don't cancel aborted jobs immediately (#11375)
Browse files Browse the repository at this point in the history
* Don't cancel aborted jobs immediately

Rather than cancelling Futures that capture jobs' logic,
this change introduces a two-level system:

- interrupt all jobs softly via ThreadInterrupted at safepoints
- if safepoint is not executed within some time period or it is
  but the job is still not cancelled, trigger a hard-interrupt
  by cancelling the job explicitly, if possible

Closes #11084.

* Only cancel Future when you mean it

Soft-cancelling a future only to later call it with `mayInterrupt` set
to `true` has no effect in the latter case.
Changed the logic so that interrupting a Future will really enforce it.

Occasionally some commands should not attempt to run soft cancellations
- we know they will re-execute the program.

* Replace Thread.sleep with Future.get

No while loops etc.; it's much easier to reason about what soft and
hard interrupts are supposed to do.

* Better comments/logs

* nit

* PR review

* Make test more robust
  • Loading branch information
hubertp authored Nov 5, 2024
1 parent 988316f commit 35e5ed5
Show file tree
Hide file tree
Showing 26 changed files with 257 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import org.enso.languageserver.runtime.{
}
import org.enso.languageserver.util.UnhandledLogging

import java.util.UUID
import scala.concurrent.duration.FiniteDuration

/** A request handler for `executionContext/modifyVisualization` commands.
Expand Down Expand Up @@ -41,16 +42,23 @@ class ModifyVisualizationHandler(
)
val cancellable =
context.system.scheduler.scheduleOnce(timeout, self, RequestTimeout)
context.become(responseStage(id, sender(), cancellable))
context.become(
responseStage(id, params.visualizationId, sender(), cancellable)
)
}

private def responseStage(
id: Id,
visualizationID: UUID,
replyTo: ActorRef,
cancellable: Cancellable
): Receive = {
case RequestTimeout =>
logger.error("Request [{}] timed out.", id)
logger.error(
"Request [{}] timed out for visualization {}.",
id,
visualizationID
)
replyTo ! ResponseError(Some(id), Errors.RequestTimeout)
context.stop(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ private void setExecutionEnvironment(
if (!oldEnvironmentName.equals(executionEnvironment.name())) {
ctx.jobControlPlane()
.abortJobs(
contextId, "set execution environment to " + executionEnvironment.name());
contextId,
"set execution environment to " + executionEnvironment.name(),
false);
ctx.locking()
.withWriteCompilationLock(
this.getClass(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ public interface JobControlPlane {
* Aborts jobs that relate to the specified execution context.
*
* @param contextId an identifier of a context
* @param reason reason for aborting job(s)
* @param softAbortFirst true if ongoing jobs should be aborted with safepoints first, even if
* marked as interruptible
* @param classOf abort jobs of a given class only. If empty all jobs for the given context are
* aborted
*/
@SuppressWarnings("unchecked")
void abortJobs(UUID contextId, String reason, Class<? extends Job<?>>... classOf);
void abortJobs(
UUID contextId, String reason, boolean softAbortFirst, Class<? extends Job<?>>... classOf);

/**
* Aborts jobs that relate to the specified execution context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public ExecuteExpressionJob(
}

@Override
public Executable run(RuntimeContext ctx) {
public Executable runImpl(RuntimeContext ctx) {
return ctx.locking()
.withContextLock(
ctx.locking().getOrCreateContextLock(contextId),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public SerializeModuleJob(QualifiedName moduleName) {
}

@Override
public Void run(RuntimeContext ctx) {
public Void runImpl(RuntimeContext ctx) {
var ensoContext = ctx.executionService().getContext();
var compiler = ensoContext.getCompiler();
boolean useGlobalCacheLocations = ensoContext.isUseGlobalCache();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DestroyContextCmd(
}

private def removeContext()(implicit ctx: RuntimeContext): Unit = {
ctx.jobControlPlane.abortJobs(request.contextId, "destroy context")
ctx.jobControlPlane.abortJobs(request.contextId, "destroy context", false)
val contextLock = ctx.locking.getOrCreateContextLock(request.contextId)
try {
ctx.locking.withContextLock(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ class InterruptContextCmd(
): Future[Unit] =
if (doesContextExist) {
Future {
ctx.jobControlPlane.abortJobs(request.contextId, "interrupt context")
ctx.jobControlPlane.abortJobs(
request.contextId,
"interrupt context",
false
)
reply(Api.InterruptContextResponse(request.contextId))
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class PopContextCmd(
ec: ExecutionContext
): Future[Unit] =
Future {
ctx.jobControlPlane.abortJobs(request.contextId, "pop context")
ctx.jobControlPlane.abortJobs(request.contextId, "pop context", false)
val maybeTopItem = ctx.contextManager.pop(request.contextId)
if (maybeTopItem.isDefined) {
reply(Api.PopContextResponse(request.contextId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PushContextCmd(
ec: ExecutionContext
): Future[Boolean] =
Future {
ctx.jobControlPlane.abortJobs(request.contextId, "push context")
ctx.jobControlPlane.abortJobs(request.contextId, "push context", false)
val stack = ctx.contextManager.getStack(request.contextId)
val pushed = request.stackItem match {
case _: Api.StackItem.ExplicitCall if stack.isEmpty =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ class RecomputeContextCmd(
ec: ExecutionContext
): Future[Boolean] = {
Future {
ctx.jobControlPlane.abortJobs(request.contextId, "recompute context")
ctx.jobControlPlane.abortJobs(
request.contextId,
"recompute context",
false
)
val stack = ctx.contextManager.getStack(request.contextId)
if (stack.isEmpty) {
reply(Api.EmptyStackError(request.contextId))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package org.enso.interpreter.instrument.execution

import com.oracle.truffle.api.TruffleLogger
import org.enso.common.Asserts.assertInJvm
import org.enso.interpreter.instrument.InterpreterContext
import org.enso.interpreter.instrument.job.{BackgroundJob, Job, UniqueJob}
import org.enso.text.Sha3_224VersionCalculator

import java.util
import java.util.{Collections, UUID}
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.ExecutorService
import java.util.concurrent.{ExecutorService, TimeUnit}
import java.util.logging.Level

import scala.concurrent.{Future, Promise}
import scala.concurrent.{Future, Promise, TimeoutException}
import scala.util.control.NonFatal

/** This component schedules the execution of jobs. It keeps a queue of
Expand All @@ -37,6 +37,9 @@ final class JobExecutionEngine(

private val context = interpreterContext.executionService.getContext

private val pendingCancellationsExecutor =
context.newFixedThreadPool(1, "pending-cancellations", false)

private val jobParallelism = context.getJobParallelism

private var isBackgroundJobsStarted = false
Expand Down Expand Up @@ -78,6 +81,82 @@ final class JobExecutionEngine(
private lazy val logger: TruffleLogger =
runtimeContext.executionService.getLogger

// Independent Runnable that is handed a list of jobs, each paired with the
// time (in epoch milliseconds) at which its cancellation was requested.
// Every job is given `forceInterruptTimeout` milliseconds, counted from that
// request time, to finish on its own; a job that is still running when its
// window expires gets a thread dump logged and is cancelled explicitly via
// `Future.cancel`.
private class ForceJobCancellations(val pendingJobs: Seq[(Long, RunningJob)])
    extends Runnable {
  // Grace period, in milliseconds, between the cancellation request and the
  // forced cancellation.
  private val forceInterruptTimeout: Long = 50 * 1000

  override def run(): Unit = {
    // Oldest requests first, so jobs whose window expires soonest are handled
    // before the others.
    pendingJobs.sortBy(_._1).foreach {
      case (timeRequestedToCancel, runningJob) =>
        try {
          val now                        = System.currentTimeMillis()
          val timeSinceRequestedToCancel = now - timeRequestedToCancel
          // This task may start within the same millisecond in which the
          // request was recorded, so an elapsed time of zero is legitimate.
          assertInJvm(timeSinceRequestedToCancel >= 0)
          // Waiting on earlier jobs may already have used up this job's
          // window; never report or request a negative wait.
          val timeToCancel =
            math.max(0L, forceInterruptTimeout - timeSinceRequestedToCancel)
          logger.log(
            Level.FINEST,
            "About to wait {}ms to cancel job {}",
            Array[Any](
              timeToCancel,
              runningJob.id
            )
          )
          runningJob.future.get(timeToCancel, TimeUnit.MILLISECONDS)
          logger.log(
            Level.FINEST,
            "Job {} finished within the allocated soft-cancel time",
            Array[Any](runningJob.id)
          )
        } catch {
          case _: TimeoutException =>
            // The job did not finish in time: record where every thread is
            // and then cancel the job's future explicitly.
            val sb = new StringBuilder(
              s"Threaddump when timeout is reached while waiting for the job ${runningJob.id} running in thread ${runningJob.job
                .threadNameExecutingJob()} to cancel:\n"
            )
            Thread.getAllStackTraces.entrySet.forEach { entry =>
              sb.append(entry.getKey.getName).append("\n")
              entry.getValue.foreach { e =>
                sb.append("    ")
                  .append(e.getClassName)
                  .append(".")
                  .append(e.getMethodName)
                  .append("(")
                  .append(e.getFileName)
                  .append(":")
                  .append(e.getLineNumber)
                  .append(")\n")
              }
            }
            logger.log(Level.WARNING, sb.toString())
            runningJob.future.cancel(runningJob.job.mayInterruptIfRunning)
          case _: InterruptedException =>
            // The worker itself was interrupted (e.g. by `shutdownNow()` on
            // its executor). Restore the flag so that the waits for the
            // remaining jobs fail fast instead of blocking for a full window.
            Thread.currentThread().interrupt()
          case NonFatal(e) =>
            logger.log(
              Level.WARNING,
              "Encountered exception while waiting on status of pending jobs",
              e
            )
        }
    }
  }
}

/** Decides whether a running job's future is cancelled right away or left
  * for a delayed, forced cancellation.
  *
  * Cancellation is delayed when the job may be interrupted while running and
  * `softAbortFirst` is set, or when the job has not started yet. Otherwise
  * the job's future is cancelled immediately.
  *
  * @param runningJob the job to cancel
  * @param softAbortFirst whether an interruptible job should be left running
  *                       for now rather than having its future cancelled
  * @return the job if its cancellation was delayed, `None` if its future
  *         has been cancelled by this call
  */
private def maybeForceCancelRunningJob(
  runningJob: RunningJob,
  softAbortFirst: Boolean
): Option[RunningJob] = {
  val job = runningJob.job
  val cancelLater =
    (job.mayInterruptIfRunning && softAbortFirst) || !job.hasStarted()
  if (!cancelLater) {
    runningJob.future.cancel(job.mayInterruptIfRunning)
    None
  } else {
    Some(runningJob)
  }
}

/** @inheritdoc */
override def runBackground[A](job: BackgroundJob[A]): Unit =
synchronized {
Expand Down Expand Up @@ -118,14 +197,36 @@ final class JobExecutionEngine(
case jobRef: UniqueJob[_] if jobRef.equalsTo(job) =>
logger
.log(Level.FINEST, s"Cancelling duplicate job [$jobRef].")
runningJob.future.cancel(jobRef.mayInterruptIfRunning)
updatePendingCancellations(
maybeForceCancelRunningJob(
runningJob,
softAbortFirst = true
).toSeq
)
case _ =>
}
}
case _ =>
}
}

/** Hands the given jobs over to the pending-cancellations executor, which
  * runs a [[ForceJobCancellations]] task for them. Every job is paired with
  * the timestamp taken at the start of this call. Does nothing when the
  * sequence is empty.
  *
  * @param jobsToCancel jobs whose cancellation has been delayed
  */
private def updatePendingCancellations(
  jobsToCancel: Seq[RunningJob]
): Unit = {
  val at = System.currentTimeMillis()
  if (jobsToCancel.nonEmpty) {
    val described = jobsToCancel.map(j => (j.job.getClass, j.id))
    logger.log(
      Level.FINEST,
      "Submitting {0} job(s) for future cancellation",
      described
    )
    val stamped = jobsToCancel.map(pendingJob => (at, pendingJob))
    pendingCancellationsExecutor.submit(new ForceJobCancellations(stamped))
  }
}

private def runInternal[A](
job: Job[A],
executorService: ExecutorService,
Expand Down Expand Up @@ -164,6 +265,7 @@ final class JobExecutionEngine(
)
}
})
job.setJobId(jobId)
val runningJob = RunningJob(jobId, job, future)

val queue = runningJobsRef.updateAndGet(_ :+ runningJob)
Expand Down Expand Up @@ -192,9 +294,11 @@ final class JobExecutionEngine(
"Aborting {0} jobs because {1}: {2}",
Array[Any](cancellableJobs.length, reason, cancellableJobs.map(_.id))
)
cancellableJobs.foreach { runningJob =>
runningJob.future.cancel(runningJob.job.mayInterruptIfRunning)
}

val pending = cancellableJobs.flatMap(
maybeForceCancelRunningJob(_, softAbortFirst = true)
)
updatePendingCancellations(pending)
runtimeContext.executionService.getContext.getThreadManager
.interruptThreads()
}
Expand All @@ -203,23 +307,27 @@ final class JobExecutionEngine(
override def abortJobs(
  contextId: UUID,
  reason: String,
  softAbortFirst: Boolean,
  toAbort: Class[_ <: Job[_]]*
): Unit = {
  val allJobs     = runningJobsRef.get()
  val contextJobs = allJobs.filter(_.job.contextIds.contains(contextId))
  val pending = contextJobs
    .flatMap { runningJob =>
      // The class filter has to look at the wrapped job: `runningJob` itself
      // is the bookkeeping wrapper, whose class never equals a `Job` class.
      // An empty `toAbort` selects every cancellable job of the context.
      if (
        runningJob.job.isCancellable && (toAbort.isEmpty || toAbort
          .contains(runningJob.job.getClass))
      ) {
        logger.log(
          Level.FINE,
          "Aborting job {0} because {1}",
          Array[Any](runningJob.id, reason)
        )
        Some(runningJob)
      } else None
    }
    // Cancel right away where required; keep the jobs whose cancellation is
    // delayed.
    .flatMap(maybeForceCancelRunningJob(_, softAbortFirst))
  // Delayed jobs are handed over for a later, forced cancellation.
  updatePendingCancellations(pending)
  runtimeContext.executionService.getContext.getThreadManager
    .interruptThreads()
}
Expand All @@ -232,16 +340,19 @@ final class JobExecutionEngine(
): Unit = {
val allJobs = runningJobsRef.get()
val contextJobs = allJobs.filter(_.job.contextIds.contains(contextId))
contextJobs.foreach { runningJob =>
if (runningJob.job.isCancellable && accept.apply(runningJob.job)) {
logger.log(
Level.FINE,
"Aborting job {0} because {1}",
Array[Any](runningJob.id, reason)
)
runningJob.future.cancel(runningJob.job.mayInterruptIfRunning)
val pending = contextJobs
.flatMap { runningJob =>
if (runningJob.job.isCancellable && accept.apply(runningJob.job)) {
logger.log(
Level.FINE,
"Aborting job {0} because {1}",
Array[Any](runningJob.id, reason)
)
Some(runningJob)
} else None
}
}
.flatMap(maybeForceCancelRunningJob(_, softAbortFirst = true))
updatePendingCancellations(pending)
runtimeContext.executionService.getContext.getThreadManager
.interruptThreads()
}
Expand All @@ -262,9 +373,10 @@ final class JobExecutionEngine(
"Aborting {0} background jobs because {1}: {2}",
Array[Any](cancellableJobs.length, reason, cancellableJobs.map(_.id))
)
cancellableJobs.foreach { runningJob =>
runningJob.future.cancel(runningJob.job.mayInterruptIfRunning)
}
val pending = cancellableJobs.flatMap(
maybeForceCancelRunningJob(_, softAbortFirst = true)
)
updatePendingCancellations(pending)
}

/** @inheritdoc */
Expand Down Expand Up @@ -292,6 +404,7 @@ final class JobExecutionEngine(
.interruptThreads()
jobExecutor.shutdownNow()
backgroundJobExecutor.shutdownNow()
pendingCancellationsExecutor.shutdownNow()
}

/** Submit background jobs preserving the stable order. */
Expand Down
Loading

0 comments on commit 35e5ed5

Please sign in to comment.