Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Cats Effect Random for randomness in fullJitter #552

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ inThisBuild(
)
)

val catsVersion = "2.12.0"
val catsEffectVersion = "3.5.7"
val catsMtlVersion = "1.5.0"
val munitVersion = "1.0.0"
val munitCatsEffectVersion = "2.0.0"
val disciplineVersion = "2.0.0"
val catsVersion = "2.12.0"
val catsEffectVersion = "3.5.7"
val catsMtlVersion = "1.5.0"
val munitVersion = "1.0.0"
val munitCatsEffectVersion = "2.0.0"
val disciplineVersion = "2.0.0"
val scalacheckEffectVersion = "1.0.4"

val core = crossProject(JVMPlatform, JSPlatform)
.in(file("modules/core"))
Expand All @@ -60,10 +61,11 @@ val core = crossProject(JVMPlatform, JSPlatform)
libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-core" % catsVersion,
"org.typelevel" %%% "cats-effect" % catsEffectVersion,
"org.scalameta" %%% "munit-scalacheck" % munitVersion % Test,
"org.typelevel" %%% "munit-cats-effect" % munitCatsEffectVersion % Test,
"org.typelevel" %%% "cats-laws" % catsVersion % Test,
"org.typelevel" %%% "discipline-munit" % disciplineVersion % Test
"org.scalameta" %%% "munit-scalacheck" % munitVersion % Test,
"org.typelevel" %%% "munit-cats-effect" % munitCatsEffectVersion % Test,
"org.typelevel" %%% "scalacheck-effect" % scalacheckEffectVersion % Test,
"org.typelevel" %%% "cats-laws" % catsVersion % Test,
"org.typelevel" %%% "discipline-munit" % disciplineVersion % Test
),
mimaPreviousArtifacts := Set.empty,
Test / tpolecatExcludeOptions += ScalacOptions.warnNonUnitStatement
Expand Down
16 changes: 9 additions & 7 deletions modules/core/shared/src/main/scala/retry/RetryPolicies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package retry
import java.util.concurrent.TimeUnit

import cats.Applicative
import cats.effect.std.Random
import cats.syntax.functor.*
import cats.syntax.show.*
import retry.PolicyDecision.*

import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.util.Random

object RetryPolicies:
private val LongMax: BigInt = BigInt(Long.MaxValue)
Expand Down Expand Up @@ -89,13 +89,15 @@ object RetryPolicies:
/** "Full jitter" backoff algorithm. See
* https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
*/
def fullJitter[F[_]: Applicative](baseDelay: FiniteDuration): RetryPolicy[F, Any] =
RetryPolicy.liftWithShow(
def fullJitter[F[_]: Applicative: Random](baseDelay: FiniteDuration): RetryPolicy[F, Any] =
RetryPolicy.withShow(
{ (_, status) =>
val e = Math.pow(2.0, status.retriesSoFar.toDouble).toLong
val maxDelay = safeMultiply(baseDelay, e)
val delayNanos = (maxDelay.toNanos * Random.nextDouble()).toLong
DelayAndRetry(new FiniteDuration(delayNanos, TimeUnit.NANOSECONDS))
val e = Math.pow(2.0, status.retriesSoFar.toDouble).toLong
val maxDelay = safeMultiply(baseDelay, e)
Random[F].nextDouble.map { rnd =>
val delayNanos = (maxDelay.toNanos * rnd).toLong
DelayAndRetry(new FiniteDuration(delayNanos, TimeUnit.NANOSECONDS))
}
},
show"fullJitter(baseDelay=$baseDelay)"
)
Expand Down
84 changes: 53 additions & 31 deletions modules/core/shared/src/test/scala/retry/RetryPoliciesSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@ package retry

import java.util.concurrent.TimeUnit

import retry.RetryPolicies.*
import cats.Id
import cats.effect.IO
import cats.effect.std.Random
import cats.syntax.all.*
import org.scalacheck.{Arbitrary, Gen}
import org.scalacheck.Prop.forAll
import munit.ScalaCheckSuite
import org.scalacheck.effect.PropF
import munit.{CatsEffectSuite, ScalaCheckSuite}
import retry.PolicyDecision.{DelayAndRetry, GiveUp}
import retry.RetryPolicies.*

import scala.concurrent.duration.*
import munit.Location

class RetryPoliciesSuite extends ScalaCheckSuite:
class RetryPoliciesSuite extends CatsEffectSuite with ScalaCheckSuite:

given Arbitrary[RetryStatus] = Arbitrary {
for
Expand All @@ -29,14 +33,14 @@ class RetryPoliciesSuite extends ScalaCheckSuite:
val genFiniteDuration: Gen[FiniteDuration] =
Gen.posNum[Long].map(FiniteDuration(_, TimeUnit.NANOSECONDS))

given Arbitrary[RetryPolicy[Id, Any]] = Arbitrary {
given (using Random[IO]): Arbitrary[RetryPolicy[IO, Any]] = Arbitrary {
Gen.oneOf(
Gen.const(alwaysGiveUp[Id]),
genFiniteDuration.map(delay => constantDelay[Id](delay)),
genFiniteDuration.map(baseDelay => exponentialBackoff[Id](baseDelay)),
Gen.posNum[Int].map(maxRetries => limitRetries[Id](maxRetries)),
genFiniteDuration.map(baseDelay => fibonacciBackoff[Id](baseDelay)),
genFiniteDuration.map(baseDelay => fullJitter[Id](baseDelay))
Gen.const(alwaysGiveUp[IO]),
genFiniteDuration.map(delay => constantDelay[IO](delay)),
genFiniteDuration.map(baseDelay => exponentialBackoff[IO](baseDelay)),
Gen.posNum[Int].map(maxRetries => limitRetries[IO](maxRetries)),
genFiniteDuration.map(baseDelay => fibonacciBackoff[IO](baseDelay)),
genFiniteDuration.map(baseDelay => fullJitter[IO](baseDelay))
)
}

Expand Down Expand Up @@ -94,39 +98,57 @@ class RetryPoliciesSuite extends ScalaCheckSuite:
}

test("fullJitter - implement the AWS Full Jitter backoff algorithm") {
val policy = fullJitter[Id](100.milliseconds)
val mkPolicy: IO[RetryPolicy[IO, Any]] = Random.scalaUtilRandom[IO].map { rnd =>
given Random[IO] = rnd
fullJitter[IO](100.milliseconds)
}
val arbitraryCumulativeDelay = 999.milliseconds
val arbitraryPreviousDelay = Some(999.milliseconds)

def check(retriesSoFar: Int, expectedMaximumDelay: FiniteDuration): Unit =
case class TestCase(retriesSoFar: Int, expectedMaximumDelay: FiniteDuration)

def check(testCase: TestCase): IO[Unit] =
val status = RetryStatus(
retriesSoFar,
testCase.retriesSoFar,
arbitraryCumulativeDelay,
arbitraryPreviousDelay
)
for _ <- 1 to 1000 do
val verdict = policy.decideNextRetry((), status)
val delay = verdict.asInstanceOf[PolicyDecision.DelayAndRetry].delay
assert(delay >= Duration.Zero)
assert(delay < expectedMaximumDelay)
(1 to 1000).toList.traverse_ { i =>
for
policy <- mkPolicy
verdict <- policy.decideNextRetry((), status)
yield
val delay = verdict.asInstanceOf[PolicyDecision.DelayAndRetry].delay
assert(clue(delay) >= Duration.Zero)
assert(clue(delay) < clue(testCase.expectedMaximumDelay))
}

check(0, 100.milliseconds)
check(1, 200.milliseconds)
check(2, 400.milliseconds)
check(3, 800.milliseconds)
check(4, 1600.milliseconds)
check(5, 3200.milliseconds)
val cases = List(
TestCase(0, 100.milliseconds),
TestCase(1, 200.milliseconds),
TestCase(2, 400.milliseconds),
TestCase(3, 800.milliseconds),
TestCase(4, 1600.milliseconds),
TestCase(5, 3200.milliseconds)
)

cases.traverse_(check)
}

property(
test(
"all built-in policies - never try to create a FiniteDuration of more than Long.MaxValue nanoseconds"
) {
forAll((policy: RetryPolicy[Id, Any], status: RetryStatus) =>
policy.decideNextRetry((), status) match
case PolicyDecision.DelayAndRetry(nextDelay) =>
nextDelay.toNanos <= Long.MaxValue
case PolicyDecision.GiveUp => true
)
Random.scalaUtilRandom[IO].map { rnd =>
given Random[IO] = rnd
PropF.forAllF((policy: RetryPolicy[IO, Any], status: RetryStatus) =>
policy.decideNextRetry((), status).map {
case PolicyDecision.DelayAndRetry(nextDelay) =>
assert(nextDelay.toNanos <= Long.MaxValue)
case PolicyDecision.GiveUp =>
assert(true)
}
)
}
}

property("limitRetries - retry with no delay until the limit is reached") {
Expand Down
Loading