diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 66e204fee44b9..296d270a2fafc 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -325,10 +325,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       // SPARK-35009: avoid creating multiple monitor threads for the same python worker
       // and task context
       if (PythonRunner.runningMonitorThreads.add(key)) {
-        new MonitorThread(SparkEnv.get, worker, context).start()
+        new MonitorThread(SparkEnv.get, worker, context, releasedOrClosed).start()
       }
     } else {
-      new MonitorThread(SparkEnv.get, worker, context).start()
+      new MonitorThread(SparkEnv.get, worker, context, releasedOrClosed).start()
     }
 
     // Return an iterator that read lines from the process's stdout
@@ -685,7 +685,11 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
    * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
    * threads can block indefinitely.
    */
-  class MonitorThread(env: SparkEnv, worker: PythonWorker, context: TaskContext)
+  class MonitorThread(
+      env: SparkEnv,
+      worker: PythonWorker,
+      context: TaskContext,
+      releasedOrClosed: AtomicBoolean)
     extends Thread(s"Worker Monitor for $pythonExec") {
 
     /** How long to wait before killing the python worker if a task cannot be interrupted. */
@@ -701,7 +705,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       }
       if (!context.isCompleted()) {
         Thread.sleep(taskKillTimeout)
-        if (!context.isCompleted()) {
+        if (!context.isCompleted() && releasedOrClosed.compareAndSet(false, true)) {
           try {
             // Mimic the task name used in `Executor` to help the user find out the task to blame.
             val taskName = s"${context.partitionId()}.${context.attemptNumber()} " +
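The key change above is that `MonitorThread` now receives the shared `releasedOrClosed` flag and only proceeds to kill the worker when it wins `compareAndSet(false, true)`, presumably so it cannot destroy a worker that some other path has already released or closed. Below is a minimal, self-contained sketch of that guard pattern, not part of the patch; `FakeWorker` and `CompareAndSetGuardSketch` are hypothetical names used only for this illustration.

```scala
import java.util.concurrent.atomic.AtomicBoolean

object CompareAndSetGuardSketch {
  // Hypothetical stand-in for a Python worker handle; not Spark's API.
  final class FakeWorker {
    def release(): Unit = println("worker released back to the pool")
    def destroy(): Unit = println("worker forcibly destroyed")
  }

  def main(args: Array[String]): Unit = {
    val worker = new FakeWorker
    // One flag shared by the normal-completion path and the monitor path,
    // mirroring the `releasedOrClosed` flag threaded through the diff.
    val releasedOrClosed = new AtomicBoolean(false)

    // Normal path: the task finished, so release the worker exactly once.
    val releaser = new Thread(() =>
      if (releasedOrClosed.compareAndSet(false, true)) worker.release())

    // Monitor path: only destroy the worker if nobody released it first.
    val monitor = new Thread(() =>
      if (releasedOrClosed.compareAndSet(false, true)) worker.destroy())

    releaser.start(); monitor.start()
    releaser.join(); monitor.join()
    // Exactly one of "released" or "destroyed" is printed, never both.
  }
}
```

Whichever thread flips the flag first performs its action; the loser sees `true` and does nothing, which is the same race-resolution idea the patch applies to the monitor's kill decision.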