diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala index fbc954d05af82..aa6fadea8a302 100644 --- a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala @@ -28,50 +28,67 @@ import org.apache.spark.SparkFunSuite class UninterruptibleThreadSuite extends SparkFunSuite { - /** Sleep millis and return true if it's interrupted */ + /* Sleep millis and return true if it's interrupted */ private def sleep(millis: Long): Boolean = { try { Thread.sleep(millis) false } catch { - case _: InterruptedException => + case e: InterruptedException => + log.error("Thread interrupted during sleep", e) + true + } + } + + /* Await latch and return true if it's interrupted */ + private def await(latch: CountDownLatch, timeout: Long = 10, + timeUnit: TimeUnit = TimeUnit.SECONDS): Boolean = { + try { + if (!latch.await(timeout, timeUnit)) { + log.error("timeout while waiting for the latch") + fail("timeout while waiting for the latch") + } + false + } catch { + case e: InterruptedException => + log.error("Thread interrupted during await", e) true } } test("interrupt when runUninterruptibly is running") { val enterRunUninterruptibly = new CountDownLatch(1) + val interruptLatch = new CountDownLatch(1) @volatile var hasInterruptedException = false @volatile var interruptStatusBeforeExit = false - val t = new UninterruptibleThread("test") { + val t = new UninterruptibleThread("runUninterruptibly") { override def run(): Unit = { runUninterruptibly { enterRunUninterruptibly.countDown() - hasInterruptedException = sleep(1000) + hasInterruptedException = await(interruptLatch) } interruptStatusBeforeExit = Thread.interrupted() } } t.start() - assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout") + assert(!await(enterRunUninterruptibly), "await interrupted") t.interrupt() + interruptLatch.countDown() t.join() - assert(hasInterruptedException === false) - assert(interruptStatusBeforeExit) + assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") + assert(interruptStatusBeforeExit, "interrupt flag should be set") } test("interrupt before runUninterruptibly runs") { val interruptLatch = new CountDownLatch(1) @volatile var hasInterruptedException = false @volatile var interruptStatusBeforeExit = false - val t = new UninterruptibleThread("test") { + val t = new UninterruptibleThread("runUninterruptibly") { override def run(): Unit = { - Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) - try { - runUninterruptibly { - } - } catch { - case _: InterruptedException => hasInterruptedException = true + assert(Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)) + assert(isInterrupted, "interrupt flag should be set") + runUninterruptibly { + hasInterruptedException = sleep(0) } interruptStatusBeforeExit = Thread.interrupted() } @@ -80,8 +97,8 @@ class UninterruptibleThreadSuite extends SparkFunSuite { t.interrupt() interruptLatch.countDown() t.join() - assert(hasInterruptedException === false) - assert(interruptStatusBeforeExit) + assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") + assert(interruptStatusBeforeExit, "interrupt flag should be set") } test("nested runUninterruptibly") { @@ -89,90 +106,94 @@ class UninterruptibleThreadSuite extends SparkFunSuite { val interruptLatch = new CountDownLatch(1) @volatile var hasInterruptedException = false @volatile var interruptStatusBeforeExit = false - val t = new UninterruptibleThread("test") { + val t = new UninterruptibleThread("runUninterruptibly") { override def run(): Unit = { runUninterruptibly { enterRunUninterruptibly.countDown() - Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) - hasInterruptedException = sleep(1) - runUninterruptibly { - if (sleep(1)) { - hasInterruptedException = true + hasInterruptedException = await(interruptLatch) + if (!hasInterruptedException) { + runUninterruptibly { + hasInterruptedException = sleep(0) } - } - if (sleep(1)) { - hasInterruptedException = true + hasInterruptedException |= sleep(0) } } interruptStatusBeforeExit = Thread.interrupted() } } t.start() - assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout") + assert(!await(enterRunUninterruptibly), "await interrupted") t.interrupt() interruptLatch.countDown() + assert(!sleep(0), "sleep should not be interrupted") + t.interrupt() + assert(!sleep(0), "sleep should not be interrupted") + t.interrupt() t.join() - assert(hasInterruptedException === false) - assert(interruptStatusBeforeExit) + assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") + assert(interruptStatusBeforeExit, "interrupt flag should be set") } test("no runUninterruptibly") { @volatile var hasInterruptedException = false - val t = new UninterruptibleThread("test") { + @volatile var interruptStatusBeforeExit = false + val t = new UninterruptibleThread("run") { override def run(): Unit = { - if (sleep(0)) { - hasInterruptedException = true - } + hasInterruptedException = sleep(0) + interruptStatusBeforeExit = Thread.interrupted() } } t.interrupt() t.start() t.join() - assert(hasInterruptedException === true) + assert(hasInterruptedException, "run should be interrupted") + assert(!interruptStatusBeforeExit, "interrupt flag should not be set") } test("SPARK-51821 uninterruptibleLock deadlock") { - val latch = new CountDownLatch(1) - val task = new UninterruptibleThread("task thread") { + val interruptLatch = new CountDownLatch(1) + val t = new UninterruptibleThread("run") { override def run(): Unit = { val channel = new AbstractInterruptibleChannel() { override def implCloseChannel(): Unit = { begin() - latch.countDown() + interruptLatch.countDown() try { Thread.sleep(Long.MaxValue) } catch { - case _: InterruptedException => Thread.currentThread().interrupt() + case e: InterruptedException => + log.info("sleep interrupted", e) + Thread.currentThread().interrupt() } } } channel.close() } } - task.start() - assert(latch.await(10, TimeUnit.SECONDS), "await timeout") - task.interrupt() - task.join() + t.start() + assert(!await(interruptLatch), "await interrupted") + t.interrupt() + t.join() } test("stress test") { + for (i <- 0 until 20) { + stressTest(i) + } + } + + def stressTest(i: Int): Unit = { @volatile var hasInterruptedException = false - val t = new UninterruptibleThread("test") { + val t = new UninterruptibleThread(s"stress test $i") { override def run(): Unit = { - for (i <- 0 until 100) { + for (i <- 0 until 100 if !hasInterruptedException) { try { runUninterruptibly { - if (sleep(Random.nextInt(10))) { - hasInterruptedException = true - } + hasInterruptedException = sleep(Random.nextInt(10)) runUninterruptibly { - if (sleep(Random.nextInt(10))) { - hasInterruptedException = true - } - } - if (sleep(Random.nextInt(10))) { - hasInterruptedException = true + hasInterruptedException |= sleep(Random.nextInt(10)) } + hasInterruptedException |= sleep(Random.nextInt(10)) } Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS) // 50% chance to clear the interrupted status @@ -180,30 +201,24 @@ class UninterruptibleThreadSuite extends SparkFunSuite { Thread.interrupted() } } catch { - case _: InterruptedException => - // The first runUninterruptibly may throw InterruptedException if the interrupt status - // is set before running `f`. + case _: InterruptedException => hasInterruptedException = true } } } } - t.start() - val threads = new Array[Thread](10) - for (j <- 0 until 10) { - threads(j) = new Thread() { + val threads = Array.fill(Runtime.getRuntime.availableProcessors)( + new Thread() { override def run(): Unit = { - for (i <- 0 until 400) { + for (i <- 0 until 400 if !hasInterruptedException) { Thread.sleep(Random.nextInt(10)) t.interrupt() } } - } - threads(j).start() - } - for (j <- 0 until 10) { - threads(j).join() - } + }) + t.start() + threads.foreach(t => t.start()) + threads.foreach(t => t.join()) t.join() - assert(hasInterruptedException === false) + assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") } }