-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-53622][CORE][TEST] Improve UninterruptibleThreadSuite
#52373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,50 +28,67 @@ import org.apache.spark.SparkFunSuite | |
|
||
class UninterruptibleThreadSuite extends SparkFunSuite { | ||
|
||
/** Sleep millis and return true if it's interrupted */ | ||
/* Sleep millis and return true if it's interrupted */ | ||
private def sleep(millis: Long): Boolean = { | ||
try { | ||
Thread.sleep(millis) | ||
false | ||
} catch { | ||
case _: InterruptedException => | ||
case e: InterruptedException => | ||
log.error("Thread interrupted during sleep", e) | ||
true | ||
} | ||
} | ||
|
||
/* Await latch and return true if it's interrupted */ | ||
private def await(latch: CountDownLatch, timeout: Long = 10, | ||
timeUnit: TimeUnit = TimeUnit.SECONDS): Boolean = { | ||
try { | ||
if (!latch.await(timeout, timeUnit)) { | ||
log.error("timeout while waiting for the latch") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's test, and this log seems not useful as we fail right after it with the same error message. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When
|
||
fail("timeout while waiting for the latch") | ||
} | ||
false | ||
} catch { | ||
case e: InterruptedException => | ||
log.error("Thread interrupted during await", e) | ||
true | ||
} | ||
} | ||
|
||
test("interrupt when runUninterruptibly is running") { | ||
val enterRunUninterruptibly = new CountDownLatch(1) | ||
val interruptLatch = new CountDownLatch(1) | ||
@volatile var hasInterruptedException = false | ||
@volatile var interruptStatusBeforeExit = false | ||
val t = new UninterruptibleThread("test") { | ||
val t = new UninterruptibleThread("runUninterruptibly") { | ||
override def run(): Unit = { | ||
runUninterruptibly { | ||
enterRunUninterruptibly.countDown() | ||
hasInterruptedException = sleep(1000) | ||
hasInterruptedException = await(interruptLatch) | ||
} | ||
interruptStatusBeforeExit = Thread.interrupted() | ||
} | ||
} | ||
t.start() | ||
assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout") | ||
assert(!await(enterRunUninterruptibly), "await interrupted") | ||
t.interrupt() | ||
interruptLatch.countDown() | ||
t.join() | ||
assert(hasInterruptedException === false) | ||
assert(interruptStatusBeforeExit) | ||
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") | ||
assert(interruptStatusBeforeExit, "interrupt flag should be set") | ||
} | ||
|
||
test("interrupt before runUninterruptibly runs") { | ||
val interruptLatch = new CountDownLatch(1) | ||
@volatile var hasInterruptedException = false | ||
@volatile var interruptStatusBeforeExit = false | ||
val t = new UninterruptibleThread("test") { | ||
val t = new UninterruptibleThread("runUninterruptibly") { | ||
override def run(): Unit = { | ||
Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) | ||
try { | ||
runUninterruptibly { | ||
} | ||
} catch { | ||
case _: InterruptedException => hasInterruptedException = true | ||
assert(Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)) | ||
assert(isInterrupted, "interrupt flag should be set") | ||
runUninterruptibly { | ||
hasInterruptedException = sleep(0) | ||
} | ||
interruptStatusBeforeExit = Thread.interrupted() | ||
} | ||
|
@@ -80,130 +97,128 @@ class UninterruptibleThreadSuite extends SparkFunSuite { | |
t.interrupt() | ||
interruptLatch.countDown() | ||
t.join() | ||
assert(hasInterruptedException === false) | ||
assert(interruptStatusBeforeExit) | ||
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") | ||
assert(interruptStatusBeforeExit, "interrupt flag should be set") | ||
} | ||
|
||
test("nested runUninterruptibly") { | ||
val enterRunUninterruptibly = new CountDownLatch(1) | ||
val interruptLatch = new CountDownLatch(1) | ||
@volatile var hasInterruptedException = false | ||
@volatile var interruptStatusBeforeExit = false | ||
val t = new UninterruptibleThread("test") { | ||
val t = new UninterruptibleThread("runUninterruptibly") { | ||
override def run(): Unit = { | ||
runUninterruptibly { | ||
enterRunUninterruptibly.countDown() | ||
Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS) | ||
hasInterruptedException = sleep(1) | ||
runUninterruptibly { | ||
if (sleep(1)) { | ||
hasInterruptedException = true | ||
hasInterruptedException = await(interruptLatch) | ||
if (!hasInterruptedException) { | ||
runUninterruptibly { | ||
hasInterruptedException = sleep(0) | ||
} | ||
} | ||
if (sleep(1)) { | ||
hasInterruptedException = true | ||
hasInterruptedException |= sleep(0) | ||
} | ||
} | ||
interruptStatusBeforeExit = Thread.interrupted() | ||
} | ||
} | ||
t.start() | ||
assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout") | ||
assert(!await(enterRunUninterruptibly), "await interrupted") | ||
t.interrupt() | ||
interruptLatch.countDown() | ||
assert(!sleep(0), "sleep should not be interrupted") | ||
t.interrupt() | ||
assert(!sleep(0), "sleep should not be interrupted") | ||
t.interrupt() | ||
t.join() | ||
assert(hasInterruptedException === false) | ||
assert(interruptStatusBeforeExit) | ||
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") | ||
assert(interruptStatusBeforeExit, "interrupt flag should be set") | ||
} | ||
|
||
test("no runUninterruptibly") { | ||
@volatile var hasInterruptedException = false | ||
val t = new UninterruptibleThread("test") { | ||
@volatile var interruptStatusBeforeExit = false | ||
val t = new UninterruptibleThread("run") { | ||
override def run(): Unit = { | ||
if (sleep(0)) { | ||
hasInterruptedException = true | ||
} | ||
hasInterruptedException = sleep(0) | ||
interruptStatusBeforeExit = Thread.interrupted() | ||
} | ||
} | ||
t.interrupt() | ||
t.start() | ||
t.join() | ||
assert(hasInterruptedException === true) | ||
assert(hasInterruptedException, "run should be interrupted") | ||
assert(!interruptStatusBeforeExit, "interrupt flag should not be set") | ||
} | ||
|
||
test("SPARK-51821 uninterruptibleLock deadlock") { | ||
val latch = new CountDownLatch(1) | ||
val task = new UninterruptibleThread("task thread") { | ||
val interruptLatch = new CountDownLatch(1) | ||
val t = new UninterruptibleThread("run") { | ||
override def run(): Unit = { | ||
val channel = new AbstractInterruptibleChannel() { | ||
override def implCloseChannel(): Unit = { | ||
begin() | ||
latch.countDown() | ||
interruptLatch.countDown() | ||
try { | ||
Thread.sleep(Long.MaxValue) | ||
} catch { | ||
case _: InterruptedException => Thread.currentThread().interrupt() | ||
case e: InterruptedException => | ||
log.info("sleep interrupted", e) | ||
Thread.currentThread().interrupt() | ||
} | ||
} | ||
} | ||
channel.close() | ||
} | ||
} | ||
task.start() | ||
assert(latch.await(10, TimeUnit.SECONDS), "await timeout") | ||
task.interrupt() | ||
task.join() | ||
t.start() | ||
assert(!await(interruptLatch), "await interrupted") | ||
t.interrupt() | ||
t.join() | ||
} | ||
|
||
test("stress test") { | ||
for (i <- 0 until 20) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the major change tries to improve the test coverage? Since it's executed in sequence, the stress of interruptions to a single There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change does not target to increase concurrency of the stress test. It targets to reproduce SPARK-53394 |
||
stressTest(i) | ||
} | ||
} | ||
|
||
def stressTest(i: Int): Unit = { | ||
@volatile var hasInterruptedException = false | ||
val t = new UninterruptibleThread("test") { | ||
val t = new UninterruptibleThread(s"stress test $i") { | ||
override def run(): Unit = { | ||
for (i <- 0 until 100) { | ||
for (i <- 0 until 100 if !hasInterruptedException) { | ||
try { | ||
runUninterruptibly { | ||
if (sleep(Random.nextInt(10))) { | ||
hasInterruptedException = true | ||
} | ||
hasInterruptedException = sleep(Random.nextInt(10)) | ||
runUninterruptibly { | ||
if (sleep(Random.nextInt(10))) { | ||
hasInterruptedException = true | ||
} | ||
} | ||
if (sleep(Random.nextInt(10))) { | ||
hasInterruptedException = true | ||
hasInterruptedException |= sleep(Random.nextInt(10)) | ||
} | ||
hasInterruptedException |= sleep(Random.nextInt(10)) | ||
} | ||
Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS) | ||
// 50% chance to clear the interrupted status | ||
if (Random.nextBoolean()) { | ||
Thread.interrupted() | ||
} | ||
} catch { | ||
case _: InterruptedException => | ||
// The first runUninterruptibly may throw InterruptedException if the interrupt status | ||
// is set before running `f`. | ||
case _: InterruptedException => hasInterruptedException = true | ||
} | ||
} | ||
} | ||
} | ||
t.start() | ||
val threads = new Array[Thread](10) | ||
for (j <- 0 until 10) { | ||
threads(j) = new Thread() { | ||
val threads = Array.fill(Runtime.getRuntime.availableProcessors)( | ||
new Thread() { | ||
override def run(): Unit = { | ||
for (i <- 0 until 400) { | ||
for (i <- 0 until 400 if !hasInterruptedException) { | ||
Thread.sleep(Random.nextInt(10)) | ||
t.interrupt() | ||
} | ||
} | ||
} | ||
threads(j).start() | ||
} | ||
for (j <- 0 until 10) { | ||
threads(j).join() | ||
} | ||
}) | ||
t.start() | ||
threads.foreach(t => t.start()) | ||
threads.foreach(t => t.join()) | ||
t.join() | ||
assert(hasInterruptedException === false) | ||
assert(!hasInterruptedException, "runUninterruptibly should not be interrupted") | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the major difference after using this
await()
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can equally use
sleep()
andawait()
to test forInterruptedException
as they both throw an exception in case thread is interrupted. The difference is thatsleep()
does not return whileawait()
will exit once the main thread callsinterrupt()
and count down the latch.