Skip to content

Commit

Permalink
Use Cats Effect Random for randomness in fullJitter (#552)
Browse files Browse the repository at this point in the history
Replace scala.util.Random with CE Random
  • Loading branch information
cb372 authored Jan 9, 2025
1 parent 256e847 commit 116eb90
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 48 deletions.
22 changes: 12 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ inThisBuild(
)
)

val catsVersion = "2.12.0"
val catsEffectVersion = "3.5.7"
val catsMtlVersion = "1.5.0"
val munitVersion = "1.0.0"
val munitCatsEffectVersion = "2.0.0"
val disciplineVersion = "2.0.0"
val catsVersion = "2.12.0"
val catsEffectVersion = "3.5.7"
val catsMtlVersion = "1.5.0"
val munitVersion = "1.0.0"
val munitCatsEffectVersion = "2.0.0"
val disciplineVersion = "2.0.0"
val scalacheckEffectVersion = "1.0.4"

val core = crossProject(JVMPlatform, JSPlatform)
.in(file("modules/core"))
Expand All @@ -60,10 +61,11 @@ val core = crossProject(JVMPlatform, JSPlatform)
libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-core" % catsVersion,
"org.typelevel" %%% "cats-effect" % catsEffectVersion,
"org.scalameta" %%% "munit-scalacheck" % munitVersion % Test,
"org.typelevel" %%% "munit-cats-effect" % munitCatsEffectVersion % Test,
"org.typelevel" %%% "cats-laws" % catsVersion % Test,
"org.typelevel" %%% "discipline-munit" % disciplineVersion % Test
"org.scalameta" %%% "munit-scalacheck" % munitVersion % Test,
"org.typelevel" %%% "munit-cats-effect" % munitCatsEffectVersion % Test,
"org.typelevel" %%% "scalacheck-effect" % scalacheckEffectVersion % Test,
"org.typelevel" %%% "cats-laws" % catsVersion % Test,
"org.typelevel" %%% "discipline-munit" % disciplineVersion % Test
),
mimaPreviousArtifacts := Set.empty,
Test / tpolecatExcludeOptions += ScalacOptions.warnNonUnitStatement
Expand Down
16 changes: 9 additions & 7 deletions modules/core/shared/src/main/scala/retry/RetryPolicies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package retry
import java.util.concurrent.TimeUnit

import cats.Applicative
import cats.effect.std.Random
import cats.syntax.functor.*
import cats.syntax.show.*
import retry.PolicyDecision.*

import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.util.Random

object RetryPolicies:
private val LongMax: BigInt = BigInt(Long.MaxValue)
Expand Down Expand Up @@ -89,13 +89,15 @@ object RetryPolicies:
/** "Full jitter" backoff algorithm. See
* https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
*/
def fullJitter[F[_]: Applicative](baseDelay: FiniteDuration): RetryPolicy[F, Any] =
RetryPolicy.liftWithShow(
def fullJitter[F[_]: Applicative: Random](baseDelay: FiniteDuration): RetryPolicy[F, Any] =
RetryPolicy.withShow(
{ (_, status) =>
val e = Math.pow(2.0, status.retriesSoFar.toDouble).toLong
val maxDelay = safeMultiply(baseDelay, e)
val delayNanos = (maxDelay.toNanos * Random.nextDouble()).toLong
DelayAndRetry(new FiniteDuration(delayNanos, TimeUnit.NANOSECONDS))
val e = Math.pow(2.0, status.retriesSoFar.toDouble).toLong
val maxDelay = safeMultiply(baseDelay, e)
Random[F].nextDouble.map { rnd =>
val delayNanos = (maxDelay.toNanos * rnd).toLong
DelayAndRetry(new FiniteDuration(delayNanos, TimeUnit.NANOSECONDS))
}
},
show"fullJitter(baseDelay=$baseDelay)"
)
Expand Down
84 changes: 53 additions & 31 deletions modules/core/shared/src/test/scala/retry/RetryPoliciesSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@ package retry

import java.util.concurrent.TimeUnit

import retry.RetryPolicies.*
import cats.Id
import cats.effect.IO
import cats.effect.std.Random
import cats.syntax.all.*
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck.Prop.forAll
import munit.ScalaCheckSuite
import org.scalacheck.effect.PropF
import munit.{CatsEffectSuite, ScalaCheckSuite}
import retry.PolicyDecision.{DelayAndRetry, GiveUp}
import retry.RetryPolicies.*

import scala.concurrent.duration.*
import munit.Location

class RetryPoliciesSuite extends ScalaCheckSuite:
class RetryPoliciesSuite extends CatsEffectSuite with ScalaCheckSuite:

given Arbitrary[RetryStatus] = Arbitrary {
for
Expand All @@ -29,14 +33,14 @@ class RetryPoliciesSuite extends ScalaCheckSuite:
val genFiniteDuration: Gen[FiniteDuration] =
Gen.posNum[Long].map(FiniteDuration(_, TimeUnit.NANOSECONDS))

given Arbitrary[RetryPolicy[Id, Any]] = Arbitrary {
given (using Random[IO]): Arbitrary[RetryPolicy[IO, Any]] = Arbitrary {
Gen.oneOf(
Gen.const(alwaysGiveUp[Id]),
genFiniteDuration.map(delay => constantDelay[Id](delay)),
genFiniteDuration.map(baseDelay => exponentialBackoff[Id](baseDelay)),
Gen.posNum[Int].map(maxRetries => limitRetries[Id](maxRetries)),
genFiniteDuration.map(baseDelay => fibonacciBackoff[Id](baseDelay)),
genFiniteDuration.map(baseDelay => fullJitter[Id](baseDelay))
Gen.const(alwaysGiveUp[IO]),
genFiniteDuration.map(delay => constantDelay[IO](delay)),
genFiniteDuration.map(baseDelay => exponentialBackoff[IO](baseDelay)),
Gen.posNum[Int].map(maxRetries => limitRetries[IO](maxRetries)),
genFiniteDuration.map(baseDelay => fibonacciBackoff[IO](baseDelay)),
genFiniteDuration.map(baseDelay => fullJitter[IO](baseDelay))
)
}

Expand Down Expand Up @@ -94,39 +98,57 @@ class RetryPoliciesSuite extends ScalaCheckSuite:
}

test("fullJitter - implement the AWS Full Jitter backoff algorithm") {
val policy = fullJitter[Id](100.milliseconds)
val mkPolicy: IO[RetryPolicy[IO, Any]] = Random.scalaUtilRandom[IO].map { rnd =>
given Random[IO] = rnd
fullJitter[IO](100.milliseconds)
}
val arbitraryCumulativeDelay = 999.milliseconds
val arbitraryPreviousDelay = Some(999.milliseconds)

def check(retriesSoFar: Int, expectedMaximumDelay: FiniteDuration): Unit =
case class TestCase(retriesSoFar: Int, expectedMaximumDelay: FiniteDuration)

def check(testCase: TestCase): IO[Unit] =
val status = RetryStatus(
retriesSoFar,
testCase.retriesSoFar,
arbitraryCumulativeDelay,
arbitraryPreviousDelay
)
for _ <- 1 to 1000 do
val verdict = policy.decideNextRetry((), status)
val delay = verdict.asInstanceOf[PolicyDecision.DelayAndRetry].delay
assert(delay >= Duration.Zero)
assert(delay < expectedMaximumDelay)
(1 to 1000).toList.traverse_ { i =>
for
policy <- mkPolicy
verdict <- policy.decideNextRetry((), status)
yield
val delay = verdict.asInstanceOf[PolicyDecision.DelayAndRetry].delay
assert(clue(delay) >= Duration.Zero)
assert(clue(delay) < clue(testCase.expectedMaximumDelay))
}

check(0, 100.milliseconds)
check(1, 200.milliseconds)
check(2, 400.milliseconds)
check(3, 800.milliseconds)
check(4, 1600.milliseconds)
check(5, 3200.milliseconds)
val cases = List(
TestCase(0, 100.milliseconds),
TestCase(1, 200.milliseconds),
TestCase(2, 400.milliseconds),
TestCase(3, 800.milliseconds),
TestCase(4, 1600.milliseconds),
TestCase(5, 3200.milliseconds)
)

cases.traverse_(check)
}

property(
test(
"all built-in policies - never try to create a FiniteDuration of more than Long.MaxValue nanoseconds"
) {
forAll((policy: RetryPolicy[Id, Any], status: RetryStatus) =>
policy.decideNextRetry((), status) match
case PolicyDecision.DelayAndRetry(nextDelay) =>
nextDelay.toNanos <= Long.MaxValue
case PolicyDecision.GiveUp => true
)
Random.scalaUtilRandom[IO].map { rnd =>
given Random[IO] = rnd
PropF.forAllF((policy: RetryPolicy[IO, Any], status: RetryStatus) =>
policy.decideNextRetry((), status).map {
case PolicyDecision.DelayAndRetry(nextDelay) =>
assert(nextDelay.toNanos <= Long.MaxValue)
case PolicyDecision.GiveUp =>
assert(true)
}
)
}
}

property("limitRetries - retry with no delay until the limit is reached") {
Expand Down

0 comments on commit 116eb90

Please sign in to comment.