Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,50 +28,67 @@ import org.apache.spark.SparkFunSuite

class UninterruptibleThreadSuite extends SparkFunSuite {

/** Sleep millis and return true if it's interrupted */
/* Sleep millis and return true if it's interrupted */
private def sleep(millis: Long): Boolean = {
try {
Thread.sleep(millis)
false
} catch {
case _: InterruptedException =>
case e: InterruptedException =>
log.error("Thread interrupted during sleep", e)
true
}
}

/* Await latch and return true if it's interrupted */
private def await(latch: CountDownLatch, timeout: Long = 10,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the major difference after using this await()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can equally use sleep() and await() to test for InterruptedException as they both throw an exception in case thread is interrupted. The difference is that sleep() does not return while await() will exit once the main thread calls interrupt() and count down the latch.

timeUnit: TimeUnit = TimeUnit.SECONDS): Boolean = {
try {
if (!latch.await(timeout, timeUnit)) {
log.error("timeout while waiting for the latch")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's test, and this log seems not useful as we fail right after it with the same error message.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When await is used not in the main test thread, fail() does not cause thread to fail, it terminates the thread with TestFailedException (that is logged to stderr). And when this happens test fails as other conditions are not met. The log.error() logs to unit-tests.log, so there is no duplication of the error message. Note that in this case, the code is similar to:

log.error("message")
throw new Exception("message")

fail("timeout while waiting for the latch")
}
false
} catch {
case e: InterruptedException =>
log.error("Thread interrupted during await", e)
true
}
}

test("interrupt when runUninterruptibly is running") {
val enterRunUninterruptibly = new CountDownLatch(1)
val interruptLatch = new CountDownLatch(1)
@volatile var hasInterruptedException = false
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("test") {
val t = new UninterruptibleThread("runUninterruptibly") {
override def run(): Unit = {
runUninterruptibly {
enterRunUninterruptibly.countDown()
hasInterruptedException = sleep(1000)
hasInterruptedException = await(interruptLatch)
}
interruptStatusBeforeExit = Thread.interrupted()
}
}
t.start()
assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
assert(!await(enterRunUninterruptibly), "await interrupted")
t.interrupt()
interruptLatch.countDown()
t.join()
assert(hasInterruptedException === false)
assert(interruptStatusBeforeExit)
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted")
assert(interruptStatusBeforeExit, "interrupt flag should be set")
}

test("interrupt before runUninterruptibly runs") {
val interruptLatch = new CountDownLatch(1)
@volatile var hasInterruptedException = false
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("test") {
val t = new UninterruptibleThread("runUninterruptibly") {
override def run(): Unit = {
Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
try {
runUninterruptibly {
}
} catch {
case _: InterruptedException => hasInterruptedException = true
assert(Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS))
assert(isInterrupted, "interrupt flag should be set")
runUninterruptibly {
hasInterruptedException = sleep(0)
}
interruptStatusBeforeExit = Thread.interrupted()
}
Expand All @@ -80,130 +97,128 @@ class UninterruptibleThreadSuite extends SparkFunSuite {
t.interrupt()
interruptLatch.countDown()
t.join()
assert(hasInterruptedException === false)
assert(interruptStatusBeforeExit)
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted")
assert(interruptStatusBeforeExit, "interrupt flag should be set")
}

test("nested runUninterruptibly") {
val enterRunUninterruptibly = new CountDownLatch(1)
val interruptLatch = new CountDownLatch(1)
@volatile var hasInterruptedException = false
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("test") {
val t = new UninterruptibleThread("runUninterruptibly") {
override def run(): Unit = {
runUninterruptibly {
enterRunUninterruptibly.countDown()
Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
hasInterruptedException = sleep(1)
runUninterruptibly {
if (sleep(1)) {
hasInterruptedException = true
hasInterruptedException = await(interruptLatch)
if (!hasInterruptedException) {
runUninterruptibly {
hasInterruptedException = sleep(0)
}
}
if (sleep(1)) {
hasInterruptedException = true
hasInterruptedException |= sleep(0)
}
}
interruptStatusBeforeExit = Thread.interrupted()
}
}
t.start()
assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
assert(!await(enterRunUninterruptibly), "await interrupted")
t.interrupt()
interruptLatch.countDown()
assert(!sleep(0), "sleep should not be interrupted")
t.interrupt()
assert(!sleep(0), "sleep should not be interrupted")
t.interrupt()
t.join()
assert(hasInterruptedException === false)
assert(interruptStatusBeforeExit)
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted")
assert(interruptStatusBeforeExit, "interrupt flag should be set")
}

test("no runUninterruptibly") {
@volatile var hasInterruptedException = false
val t = new UninterruptibleThread("test") {
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("run") {
override def run(): Unit = {
if (sleep(0)) {
hasInterruptedException = true
}
hasInterruptedException = sleep(0)
interruptStatusBeforeExit = Thread.interrupted()
}
}
t.interrupt()
t.start()
t.join()
assert(hasInterruptedException === true)
assert(hasInterruptedException, "run should be interrupted")
assert(!interruptStatusBeforeExit, "interrupt flag should not be set")
}

test("SPARK-51821 uninterruptibleLock deadlock") {
val latch = new CountDownLatch(1)
val task = new UninterruptibleThread("task thread") {
val interruptLatch = new CountDownLatch(1)
val t = new UninterruptibleThread("run") {
override def run(): Unit = {
val channel = new AbstractInterruptibleChannel() {
override def implCloseChannel(): Unit = {
begin()
latch.countDown()
interruptLatch.countDown()
try {
Thread.sleep(Long.MaxValue)
} catch {
case _: InterruptedException => Thread.currentThread().interrupt()
case e: InterruptedException =>
log.info("sleep interrupted", e)
Thread.currentThread().interrupt()
}
}
}
channel.close()
}
}
task.start()
assert(latch.await(10, TimeUnit.SECONDS), "await timeout")
task.interrupt()
task.join()
t.start()
assert(!await(interruptLatch), "await interrupted")
t.interrupt()
t.join()
}

test("stress test") {
for (i <- 0 until 20) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the major change tries to improve the test coverage? Since it's executed in sequence, the stress of interruptions to a single UninterruptibleThread doesn't seem to be inscreased to me.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change does not target to increase concurrency of the stress test. It targets to reproduce SPARK-53394

stressTest(i)
}
}

def stressTest(i: Int): Unit = {
@volatile var hasInterruptedException = false
val t = new UninterruptibleThread("test") {
val t = new UninterruptibleThread(s"stress test $i") {
override def run(): Unit = {
for (i <- 0 until 100) {
for (i <- 0 until 100 if !hasInterruptedException) {
try {
runUninterruptibly {
if (sleep(Random.nextInt(10))) {
hasInterruptedException = true
}
hasInterruptedException = sleep(Random.nextInt(10))
runUninterruptibly {
if (sleep(Random.nextInt(10))) {
hasInterruptedException = true
}
}
if (sleep(Random.nextInt(10))) {
hasInterruptedException = true
hasInterruptedException |= sleep(Random.nextInt(10))
}
hasInterruptedException |= sleep(Random.nextInt(10))
}
Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS)
// 50% chance to clear the interrupted status
if (Random.nextBoolean()) {
Thread.interrupted()
}
} catch {
case _: InterruptedException =>
// The first runUninterruptibly may throw InterruptedException if the interrupt status
// is set before running `f`.
case _: InterruptedException => hasInterruptedException = true
}
}
}
}
t.start()
val threads = new Array[Thread](10)
for (j <- 0 until 10) {
threads(j) = new Thread() {
val threads = Array.fill(Runtime.getRuntime.availableProcessors)(
new Thread() {
override def run(): Unit = {
for (i <- 0 until 400) {
for (i <- 0 until 400 if !hasInterruptedException) {
Thread.sleep(Random.nextInt(10))
t.interrupt()
}
}
}
threads(j).start()
}
for (j <- 0 until 10) {
threads(j).join()
}
})
t.start()
threads.foreach(t => t.start())
threads.foreach(t => t.join())
t.join()
assert(hasInterruptedException === false)
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted")
}
}