Skip to content

Commit a0b526b

Browse files
committed
Add an efficient reservoir sampling aggregator
This aggregator uses Li's "Algorithm L", a simple yet efficient sampling method, with modifications to support a monoidal setting. A JMH benchmark was added for both this and the old priority-queue algoritm. In a single-threaded benchmark on an Intel Core i9-10885H, the algorithms are roughly on par for a sample rate of 10%, but Algorithm L performs much better at lower sample rates (2x-5x througput increase observed at various collection sizes). Because of this, the new algorithm was made the default for Aggregtor.reservoirSample(). Unit tests were added for both algorithms. These are probabilistic and are expected to fail on some 0.1% of times, per test case (p-value is set to 0.001).
1 parent 464917d commit a0b526b

File tree

8 files changed

+556
-6
lines changed

8 files changed

+556
-6
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package com.twitter.algebird.benchmark
2+
3+
import com.twitter.algebird.mutable.ReservoirSamplingToListAggregator
4+
import com.twitter.algebird.{Aggregator, Preparer}
5+
import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, State}
6+
import org.openjdk.jmh.infra.Blackhole
7+
8+
import scala.util.Random
9+
10+
object ReservoirSamplingBenchmark {
11+
@State(Scope.Benchmark)
12+
class BenchmarkState {
13+
@Param(Array("100", "10000", "1000000"))
14+
var collectionSize: Int = 0
15+
16+
@Param(Array("0.001", "0.01", "0.1"))
17+
var sampleRate: Double = 0.0
18+
19+
def samples: Int = (sampleRate * collectionSize).ceil.toInt
20+
}
21+
22+
val rng = new Random()
23+
implicit val randomSupplier: () => Random = () => rng
24+
}
25+
26+
class ReservoirSamplingBenchmark {
27+
import ReservoirSamplingBenchmark._
28+
29+
private def prioQueueSampler[T](count: Int) =
30+
Preparer[T]
31+
.map(rng.nextDouble() -> _)
32+
.monoidAggregate(Aggregator.sortByTake(count)(_._1))
33+
.andThenPresent(_.map(_._2))
34+
35+
@Benchmark
36+
def timeAlgorithmL(state: BenchmarkState, bh: Blackhole): Unit =
37+
bh.consume(new ReservoirSamplingToListAggregator[Int](state.samples).apply(0 until state.collectionSize))
38+
39+
@Benchmark
40+
def timePriorityQeueue(state: BenchmarkState, bh: Blackhole): Unit =
41+
bh.consume(prioQueueSampler(state.samples).apply(0 until state.collectionSize))
42+
}

algebird-core/src/main/scala/com/twitter/algebird/Aggregator.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.twitter.algebird
22

3+
import com.twitter.algebird.mutable.{Reservoir, ReservoirSamplingToListAggregator}
4+
35
import java.util.PriorityQueue
46
import scala.collection.compat._
57
import scala.collection.generic.CanBuildFrom
@@ -286,12 +288,9 @@ object Aggregator extends java.io.Serializable {
286288
def reservoirSample[T](
287289
count: Int,
288290
seed: Int = DefaultSeed
289-
): MonoidAggregator[T, PriorityQueue[(Double, T)], Seq[T]] = {
290-
val rng = new java.util.Random(seed)
291-
Preparer[T]
292-
.map(rng.nextDouble() -> _)
293-
.monoidAggregate(sortByTake(count)(_._1))
294-
.andThenPresent(_.map(_._2))
291+
): MonoidAggregator[T, Reservoir[T], Seq[T]] = {
292+
val rng = new scala.util.Random(seed)
293+
new ReservoirSamplingToListAggregator[T](count)(() => rng)
295294
}
296295

297296
/**
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
package com.twitter.algebird.mutable
2+
3+
import com.twitter.algebird.{Monoid, MonoidAggregator}
4+
5+
import scala.collection.mutable
6+
import scala.util.Random
7+
8+
/**
9+
* A reservoir of the currently sampled items.
10+
*
11+
* @param capacity
12+
* the reservoir capacity
13+
* @tparam T
14+
* the element type
15+
*/
16+
sealed class Reservoir[T](val capacity: Int) {
17+
var reservoir: mutable.Buffer[T] = mutable.Buffer()
18+
19+
// When the reservoir is full, w is the threshold for accepting an element into the reservoir, and
20+
// the following invariant holds: The maximum score of the elements in the reservoir is w,
21+
// and the remaining elements are distributed as U[0, w].
22+
// Scores are not kept explicitly, only their distribution is tracked and sampled from.
23+
// (w = 1 when the reservoir is not full.)
24+
var w: Double = 1
25+
26+
require(capacity > 0, "reservoir size must be positive")
27+
private val kInv: Double = 1d / capacity
28+
29+
def size: Int = reservoir.size
30+
def isEmpty: Boolean = reservoir.isEmpty
31+
def isFull: Boolean = size == capacity
32+
33+
/**
34+
* Add an element to the reservoir. If the reservoir is full then the element will replace a random element
35+
* in the reservoir, and the threshold <pre>w</pre> is updated.
36+
*
37+
* When adding multiple elements, [[append]] should be used to take advantage of exponential jumps.
38+
*
39+
* @param x
40+
* the element to add
41+
* @param rng
42+
* the random source
43+
*/
44+
def accept(x: T, rng: Random): Unit = {
45+
if (isFull) {
46+
reservoir(rng.nextInt(capacity)) = x
47+
} else {
48+
reservoir.append(x)
49+
}
50+
if (isFull) {
51+
w *= Math.pow(rng.nextDouble, kInv)
52+
}
53+
}
54+
55+
/**
56+
* Add multiple elements to the reservoir.
57+
* @param xs
58+
* the elements to add
59+
* @param rng
60+
* the random source
61+
* @param prior
62+
* the threshold of the elements being added, such that the added element's value is distributed as
63+
* <pre>U[0, prior]</pre>
64+
* @return
65+
* this reservoir
66+
*/
67+
def append(xs: TraversableOnce[T], rng: Random, prior: Double = 1): Reservoir[T] = {
68+
// The number of items to skip before accepting the next item is geometrically distributed
69+
// with probability of success w / prior. The prior will be 1 when adding to a single reservoir,
70+
// but when merging reservoirs it will be the threshold of the reservoir being pulled from,
71+
// and in this case we require that w < prior.
72+
def nextAcceptTime = (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt
73+
74+
var skip = if (isFull) nextAcceptTime else 0
75+
for (x <- xs) {
76+
if (!isFull) {
77+
// keep adding while reservoir is not full
78+
accept(x, rng)
79+
if (isFull) {
80+
skip = nextAcceptTime
81+
}
82+
} else if (skip > 0) {
83+
skip -= 1
84+
} else {
85+
accept(x, rng)
86+
skip = nextAcceptTime
87+
}
88+
}
89+
this
90+
}
91+
92+
override def toString: String = s"Reservoir($capacity, $w, ${reservoir.toList})"
93+
}
94+
95+
object Reservoir {
96+
implicit def monoid[T](implicit randomSupplier: () => Random): Monoid[Reservoir[T]] =
97+
new ReservoirMonoid()(randomSupplier)
98+
}
99+
100+
/**
101+
* This is the "Algorithm L" reservoir sampling algorithm [1], with modifications to act as a monoid by
102+
* merging reservoirs.
103+
*
104+
* [1] Kim-Hung Li, "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))", 1994
105+
*
106+
* @tparam T
107+
* the item type
108+
*/
109+
class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Monoid[Reservoir[T]] {
110+
111+
/**
112+
* Builds a reservoir with a single item.
113+
*
114+
* @param k
115+
* the reservoir capacity
116+
* @param x
117+
* the item to add
118+
* @return
119+
*/
120+
def build(k: Int, x: T): Reservoir[T] = {
121+
val r = new Reservoir[T](k)
122+
r.accept(x, randomSupplier())
123+
r
124+
}
125+
126+
override def zero: Reservoir[T] = new Reservoir(1)
127+
def zero(k: Int): Reservoir[T] = new Reservoir(k)
128+
override def isNonZero(r: Reservoir[T]): Boolean = !r.isEmpty
129+
130+
/**
131+
* Merge two reservoirs. NOTE: This mutates one or both of the reservoirs. They should not be used after
132+
* this operation, except as the return value for further aggregation.
133+
*/
134+
override def plus(left: Reservoir[T], right: Reservoir[T]): Reservoir[T] =
135+
if (left.isEmpty) right
136+
else if (left.size + right.size <= left.capacity) {
137+
// the sum of the sizes is less than the reservoir size, so we can just merge
138+
left.append(right.reservoir, randomSupplier())
139+
} else {
140+
val (s1, s2) = if (left.w < right.w) (left, right) else (right, left)
141+
val rng = randomSupplier()
142+
if (s2.isFull) {
143+
// The highest score in s2 is w, and the other scores are distributed as U[0, w].
144+
// Since s1.w < s2.w, we have to drop the single (sampled) element with the highest score
145+
// unconditionally. The other elements enter the reservoir with probability s1.w / s2.w.
146+
val i = rng.nextInt(s2.size)
147+
s2.reservoir(i) = s2.reservoir.head
148+
s1.append(s2.reservoir.drop(1), rng, s2.w)
149+
} else {
150+
s1.append(s2.reservoir, rng)
151+
}
152+
}
153+
}
154+
155+
/**
156+
* An aggregator that uses reservoir sampling to sample k elements from a stream of items. Because the
157+
* reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done
158+
* by [[ReservoirSamplingToListAggregator]].
159+
*
160+
* @param k
161+
* the number of elements to sample
162+
* @param randomSupplier
163+
* the random generator
164+
* @tparam T
165+
* the item type
166+
* @tparam C
167+
* the result type
168+
*/
169+
abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSupplier: () => Random)
170+
extends MonoidAggregator[T, Reservoir[T], C] {
171+
override val monoid: ReservoirMonoid[T] = new ReservoirMonoid
172+
override def prepare(x: T): Reservoir[T] = monoid.build(k, x)
173+
174+
override def apply(xs: TraversableOnce[T]): C = present(agg(xs))
175+
176+
override def applyOption(inputs: TraversableOnce[T]): Option[C] =
177+
if (inputs.isEmpty) None else Some(apply(inputs))
178+
179+
override def append(r: Reservoir[T], t: T): Reservoir[T] = r.append(Seq(t), randomSupplier())
180+
181+
override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] =
182+
r.append(xs, randomSupplier())
183+
184+
override def appendAll(xs: TraversableOnce[T]): Reservoir[T] = agg(xs)
185+
186+
private def agg(xs: TraversableOnce[T]): Reservoir[T] =
187+
appendAll(monoid.zero(k), xs)
188+
}
189+
190+
class ReservoirSamplingToListAggregator[T](k: Int)(implicit randomSupplier: () => Random)
191+
extends ReservoirSamplingAggregator[T, List[T]](k)(randomSupplier) {
192+
override def present(r: Reservoir[T]): List[T] =
193+
randomSupplier().shuffle(r.reservoir).toList
194+
195+
override def andThenPresent[D](f: List[T] => D): MonoidAggregator[T, Reservoir[T], D] =
196+
new AndThenPresent(this, f)
197+
}
198+
199+
/**
200+
* Monoid that implements [[andThenPresent]] without ruining the optimized behavior of the aggregator.
201+
*/
202+
protected class AndThenPresent[-A, B, C, +D](val agg: MonoidAggregator[A, B, C], f: C => D)
203+
extends MonoidAggregator[A, B, D] {
204+
override val monoid: Monoid[B] = agg.monoid
205+
override def prepare(a: A): B = agg.prepare(a)
206+
override def present(b: B): D = f(agg.present(b))
207+
208+
override def apply(xs: TraversableOnce[A]): D = f(agg(xs))
209+
override def applyOption(xs: TraversableOnce[A]): Option[D] = agg.applyOption(xs).map(f)
210+
override def append(b: B, a: A): B = agg.append(b, a)
211+
override def appendAll(b: B, as: TraversableOnce[A]): B = agg.appendAll(b, as)
212+
override def appendAll(as: TraversableOnce[A]): B = agg.appendAll(as)
213+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.twitter.algebird
2+
3+
import com.twitter.algebird.scalacheck.Distribution._
4+
import org.scalacheck.{Gen, Prop}
5+
6+
object RandomSamplingLaws {
7+
8+
def sampleOneUniformly[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = {
9+
val n = 100
10+
11+
"sampleOne" |: forAllSampled(10000, Gen.choose(1, 20))(_ => uniform(n)) { k =>
12+
newSampler(k).andThenPresent(_.head).apply(0 until n)
13+
}
14+
}
15+
16+
def reservoirSizeOne[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = {
17+
val n = 100
18+
19+
"reservoirSizeOne" |: forAllSampled(10000)(uniform(n)) {
20+
newSampler(1).andThenPresent(_.head).apply(0 until n)
21+
}
22+
}
23+
24+
def reservoirSizeTwo[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = {
25+
val n = 10
26+
val tuples = for {
27+
i <- 0 until n
28+
j <- 0 until n
29+
if i != j
30+
} yield (i, j)
31+
32+
"reservoirSizeTwo" |: forAllSampled(10000)(tuples.map(_ -> 1d).toMap) {
33+
newSampler(2).andThenPresent(xs => (xs(0), xs(1))).apply(0 until n)
34+
}
35+
}
36+
37+
def sampleSpecificItem[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = {
38+
val sizeAndIndex: Gen[(Int, Int)] = for {
39+
k <- Gen.choose(1, 10)
40+
i <- Gen.choose(0, k - 1)
41+
} yield (k, i)
42+
43+
val n = 100
44+
45+
"sampleAnyItem" |: forAllSampled(10000, sizeAndIndex)(_ => uniform(n)) { case (k, i) =>
46+
newSampler(k).andThenPresent(_(i)).apply(0 until n)
47+
}
48+
}
49+
50+
def sampleTwoItems[T](newSampler: Int => Aggregator[Int, T, Seq[Int]]): Prop = {
51+
val sizeAndIndexes: Gen[(Int, Int, Int)] = for {
52+
k <- Gen.choose(1, 10)
53+
i <- Gen.choose(0, k - 1)
54+
j <- Gen.choose(0, k - 1)
55+
if i != j
56+
} yield (k, i, j)
57+
58+
val n = 20
59+
60+
"sampleTwoItems" |: forAllSampled(10000, sizeAndIndexes)(_ =>
61+
(for {
62+
i <- 0 until n
63+
j <- 0 until n
64+
if i != j
65+
} yield (i, j)).map(_ -> 1d).toMap
66+
) { case (k, i, j) =>
67+
newSampler(k).andThenPresent(xs => (xs(i), xs(j))).apply(0 until n)
68+
}
69+
}
70+
71+
def randomSamplingDistributions[T](newSampler: Int => MonoidAggregator[Int, T, Seq[Int]]): Prop =
72+
sampleOneUniformly(newSampler) &&
73+
reservoirSizeOne(newSampler) &&
74+
reservoirSizeTwo(newSampler) &&
75+
sampleSpecificItem(newSampler) &&
76+
sampleTwoItems(newSampler)
77+
}

0 commit comments

Comments
 (0)