@@ -17,12 +17,15 @@
package cats.effect.benchmarks

import cats.effect.IO
import cats.effect.syntax.all._
import cats.effect.unsafe.implicits.global
import cats.implicits.{catsSyntaxParallelTraverse1, toTraverseOps}

import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

import scala.concurrent.duration._

import java.util.concurrent.TimeUnit

/**
@@ -55,6 +58,24 @@ class ParallelBenchmark {
  def parTraverse(): Unit =
    1.to(size).toList.parTraverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync()

  @Benchmark
  def parTraverseN(): Unit =
    1.to(size)
      .toList
      .parTraverseN(size / 100)(_ => IO(Blackhole.consumeCPU(cpuTokens)))
      .void
      .unsafeRunSync()

  @Benchmark
  def parTraverseNCancel(): Unit = {
    val e = new RuntimeException
    val test = 1.to(size * 100).toList.parTraverseN(size / 100) { _ =>
      IO.sleep(100.millis) *> IO.raiseError(e)
    }

    test.attempt.void.unsafeRunSync()
  }

  @Benchmark
  def traverse(): Unit =
    1.to(size).toList.traverse(_ => IO(Blackhole.consumeCPU(cpuTokens))).void.unsafeRunSync()
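Assuming the repository's usual sbt-jmh setup (not shown in this diff), these benchmarks can be run with something like benchmarks/Jmh/run cats.effect.benchmarks.ParallelBenchmark.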
@@ -139,7 +139,52 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {

Contributor: Scaladoc needs an update above.

    implicit val F: GenConcurrent[F, E] = this

-    MiniSemaphore[F](n).flatMap { sem => ta.parTraverse { a => sem.withPermit(f(a)) } }

    F.deferred[Option[E]] flatMap { preempt =>
      F.ref[Set[Fiber[F, ?, ?]]](Set()) flatMap { supervision =>
        MiniSemaphore[F](n) flatMap { sem =>
          val results = ta traverse { a =>
            preempt.tryGet flatMap {
              case Some(_) =>
                // it's okay to produce never here because the early abort precedes us
                // this effect won't get sequenced, so it can be anything really
                F.pure(F.never[B])

              case None =>
                F.uncancelable { poll =>
                  F.deferred[Outcome[F, E, B]] flatMap { result =>
                    val action = poll(sem.acquire) >> f(a)
Contributor: Is this intentionally >> and not *>? So that evaluating the pure f is also restricted by the semaphore? (In my opinion it doesn't need to be, but it's okay that it is.)

Member (author): It's intentional. I should probably comment it as such. I think most users probably believe that even the pure part of the function is parallelized (and subject to the semaphore).
                      .guaranteeCase { oc =>
                        result.complete(oc) *> oc.fold(
                          preempt.complete(None).void,
                          e => preempt.complete(Some(e)).void,
                          _ => F.unit) *> sem.release
                      }
                      .void
                      .voidError
                      .start

                    action flatMap { fiber =>
                      supervision.update(_ + fiber) map { _ =>
                        result
                          .get
                          .flatMap(_.embed(F.canceled *> F.never))
                          .onCancel(fiber.cancel)
Contributor: When is this onCancel necessary? Wouldn't the guaranteeCase below cancel everything in supervision anyway?

Member (author): I think it's not necessary. I've been building this up a bit incrementally, so there's some overlapping logic I need to deduplicate due to the number of corner cases this function has.

                          .guarantee(supervision.update(_ - fiber))
                      }
                    }
                  }
                }
            }
          }

          results.flatMap(_.sequence) guaranteeCase {
            case Outcome.Succeeded(_) => F.unit
            // has to be done in parallel to avoid head of line issues
            case _ => supervision.get.flatMap(_.toList.parTraverse_(_.cancel))
          }
        }
      }
    }
  }
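As an aside to the >>-vs-*> thread above, a minimal standalone sketch (mine, not part of the diff; the names and the constant 42 are illustrative) of the laziness difference:

import cats.effect.IO

// `>>` takes its right-hand side by name, so `f(42)` (including any pure
// work inside `f`) is only evaluated once `gate` completes, i.e. under the
// permit in the code above. `*>` evaluates its argument eagerly, while the
// composite effect is still being constructed.
def lazily(gate: IO[Unit], f: Int => IO[Int]): IO[Int] =
  gate >> f(42)

def eagerly(gate: IO[Unit], f: Int => IO[Int]): IO[Int] =
  gate *> f(42)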

/**
@@ -152,7 +197,52 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scaladoc above.

    implicit val F: GenConcurrent[F, E] = this

-    MiniSemaphore[F](n).flatMap { sem => ta.parTraverse_ { a => sem.withPermit(f(a)) } }

    // TODO we need to write a test for error cancelation
    F.deferred[Option[E]] flatMap { preempt =>
      F.ref[List[Fiber[F, ?, ?]]](Nil) flatMap { supervision =>
        MiniSemaphore[F](n) flatMap { sem =>
          val startAll = ta traverse_ { a =>
            // first check to see if any of the effects have errored out
            // don't bother starting new things if that happens
            preempt.tryGet flatMap {
              case Some(_) =>
                F.unit // allow the error to be resurfaced later

              case None =>
                F.uncancelable { poll =>
                  // if the effect produces an error, race to kill all the rest
                  val wrapped = f(a) guaranteeCase { oc =>
                    sem.release *> oc.fold(
                      preempt.complete(None).void,
                      e => preempt.complete(Some(e)).void,
                      _ => F.unit)
                  }

                  val suppressed = wrapped.void.voidError

                  poll(sem.acquire) >> suppressed.start flatMap { fiber =>
                    // supervision is handled very differently here: we never remove from the set
                    supervision.update(fiber :: _)
                  }
                }
            }
          }

          val cancelAll = supervision.get.flatMap(_.parTraverse_(_.cancel))

          startAll.onCancel(cancelAll) *>
            // we block until it's all done by acquiring all the permits
            F.race(preempt.get *> cancelAll, sem.acquire.replicateA_(n)) *>
            // if we hit an error or self-cancelation in any effect, resurface it here
            // note that we can't lose errors here because of the permits: we know the fibers are done
Contributor: I think there may be a race here:

  1. The very last task fails with an error, and releases its permit (sem.release above in wrapped).
  2. Acquiring all the permits wins the F.race (just above).
  3. Just below, we preempt.tryGet, read None, and complete with F.unit.
  4. The task completes preempt with the error (above in wrapped).
  5. (But no one will see that any more.)

Member (author): I think this is a good point. I do this in both implementations. My thinking was that it increases parallelism somewhat (releasing the permit asap), but it does generate this race condition. I'll fix it in both.

            preempt.tryGet flatMap {
              case Some(Some(e)) => F.raiseError(e)
              case Some(None) => F.canceled
              case None => F.unit
            }
        }
      }
    }
  }
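For illustration, a minimal sketch of the fix agreed in the thread above (my assumption, not the merged code): publish the outcome to preempt before releasing the permit, so that by the time all the permits can be acquired, any error is already visible to the final tryGet.

// Sketch only: reuses f, a, preempt, sem, and F from the implementation above.
val wrapped = f(a) guaranteeCase { oc =>
  oc.fold(
    preempt.complete(None).void,
    e => preempt.complete(Some(e)).void,
    _ => F.unit) *> sem.release
}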

  override def racePair[A, B](fa: F[A], fb: F[B])
@@ -27,10 +27,9 @@ import scala.collection.immutable.{Queue => ScalaQueue}
 * A cut-down version of semaphore used to implement parTraverseN
 */
private[kernel] abstract class MiniSemaphore[F[_]] extends Serializable {
  def acquire: F[Unit]
  def release: F[Unit]

-  /**
-   * Sequence an action while holding a permit
-   */
-  def withPermit[A](fa: F[A]): F[A]
}
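With withPermit removed in favor of raw acquire/release, callers now bracket the permit themselves. A minimal sketch of the standard pattern (my illustration, not code from this PR; it assumes placement inside cats.effect.kernel, since MiniSemaphore is private[kernel]):

import cats.effect.kernel.MonadCancel
import cats.syntax.all._

// Recover a withPermit-style combinator from acquire/release: the acquire
// is cancelable (polled), but once the permit is held, release is
// guaranteed to run however `fa` terminates.
def withPermit[F[_], E, A](sem: MiniSemaphore[F])(fa: F[A])(
    implicit F: MonadCancel[F, E]): F[A] =
  F.uncancelable { poll =>
    poll(sem.acquire) *> F.guarantee(poll(fa), sem.release)
  }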

tests/shared/src/test/scala/cats/effect/IOSpec.scala (140 additions, 0 deletions)
@@ -1614,6 +1614,146 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification {
      p must completeAs(true)
    }

    "run finalizers when canceled" in ticked { implicit ticker =>
      val p = for {
        r <- IO.ref(0)

        /*
         * The exact series of steps here is:
         *
         *   List(IO.never.onCancel, IO.unit, IO.never.onCancel)
         *
         * This is significant because we're limiting the parallelism to
         * 2, meaning that we will hit a wall after IO.unit. HOWEVER,
         * IO.unit completes immediately, so this test not only checks
         * cancelation, it also tests that we move onto the third item
         * after the second one completes even while the first is blocked.
         * In other words, it's testing both cancelation and head of line
         * behavior.
         */
        f <- List(1, 2, 3)
          .parTraverseN(2) { i =>
            if (i == 2) IO.unit
            else IO.never.onCancel(r.update(_ + 1))
          }
          .start

        _ <- IO.sleep(100.millis)
        _ <- f.cancel
        c <- r.get
        _ <- IO { c mustEqual 2 }
      } yield true

      p must completeAs(true)
    }

    "propagate self-cancellation" in ticked { implicit ticker =>
      List(1, 2, 3, 4)
        .parTraverseN(2) { (n: Int) =>
          if (n == 3) IO.canceled *> IO.never
          else IO.pure(n)
        }
        .void must selfCancel
    }

    "run finalizers when a task self-cancels" in ticked { implicit ticker =>
      val p = for {
        r <- IO.ref(0)
        fib <- List(1, 2, 3, 4)
          .parTraverseN(2) { (n: Int) =>
            if (n == 3) IO.canceled *> IO.never
            else IO.pure(n)
          }
          .onCancel(r.update(_ + 1))
          .void
          .start
        _ <- IO.sleep(100.millis)
        c <- r.get
        _ <- IO { c mustEqual 1 }
        oc <- fib.join
      } yield oc.isCanceled

      p must completeAs(true)
    }

    "not run more than `n` tasks at a time" in real {
      def task(counter: Ref[IO, Int], maximum: Ref[IO, Int]): IO[Unit] = {
        val acq = counter.updateAndGet(_ + 1).flatMap { count =>
          maximum.update { max => if (count > max) count else max }
        }

        IO.asyncForIO.bracket(acq) { _ => IO.sleep(100.millis) }(_ => counter.update(_ - 1))
      }

      for {
        maximum <- Ref.of[IO, Int](0)
        counter <- Ref.of[IO, Int](0)
        nCpu <- IO { Runtime.getRuntime().availableProcessors() }
        n = java.lang.Math.max(nCpu, 2)
        size = 4 * n
        res <- (1 to size).toList.parTraverseN(n) { _ => task(counter, maximum) }
        _ <- IO { res.size mustEqual size }
        count <- counter.get
        _ <- IO { count mustEqual 0 }
        max <- maximum.get
        _ <- IO { max must beLessThanOrEqualTo(n) }
      } yield ok
    }

    "run actually in parallel" in real {
      val n = 4

      (1 to 2 * n)
        .toList
        .map { i => IO.sleep(1.second).as(i) }
        .parSequenceN(n)
        .timeout(3.seconds)
        .flatMap { res => IO { res mustEqual (1 to 2 * n).toList } }
    }

    "work for empty traverse" in ticked { implicit ticker =>
      List.empty[Int].parTraverseN(4) { _ => IO.never[String] } must completeAs(
        List.empty[String])
    }

    "work for non-empty traverse (ticked)" in ticked { implicit ticker =>
      List(1).parTraverseN(4) { i => IO.pure(i.toString) } must completeAs(List("1"))
      List(1, 2).parTraverseN(3) { i => IO.pure(i.toString) } must completeAs(List("1", "2"))
      List(1, 2, 3).parTraverseN(2) { i => IO.pure(i.toString) } must completeAs(
        List("1", "2", "3"))
      List(1, 2, 3, 4).parTraverseN(1) { i => IO.pure(i.toString) } must completeAs(
        List("1", "2", "3", "4"))
    }

    "work for non-empty traverse (real)" in real {
      for {
        _ <- List(1).parTraverseN(4)(i => IO.pure(i.toString)).flatMap { r =>
          IO(r mustEqual List("1"))
        }
        _ <- List(1, 2).parTraverseN(3)(i => IO.pure(i.toString)).flatMap { r =>
          IO(r mustEqual List("1", "2"))
        }
        _ <- List(1, 2, 3).parTraverseN(2)(i => IO.pure(i.toString)).flatMap { r =>
          IO(r mustEqual List("1", "2", "3"))
        }
        _ <- List(1, 2, 3, 4).parTraverseN(1)(i => IO.pure(i.toString)).flatMap { r =>
          IO(r mustEqual List("1", "2", "3", "4"))
        }
        _ <- (1 to 10000).toList.parTraverseN(2)(i => IO.pure(i.toString)).flatMap { r =>
          IO(r mustEqual (1 to 10000).map(_.toString).toList)
        }
      } yield ok
    }

    "be null-safe" in real {
      for {
        r1 <- List[String]("a", "b", null, "d", null).parTraverseN(2) {
          case "a" => IO.pure(null)
          case "b" => IO.pure("x")
          case "d" => IO.pure(null)
          case null => IO.pure("z")
        }
        _ <- IO { r1 mustEqual List(null, "x", "z", null, "z") }
      } yield ok
    }
  }

"parTraverseN_" should {