Skip to content

Commit 7d1cfdf

Browse files
committed
Optimize sampling from IndexedSeq
Indexed sequences allow us to skip over items without examining each one.
1 parent 431d5b7 commit 7d1cfdf

File tree

3 files changed

+72
-13
lines changed

3 files changed

+72
-13
lines changed

algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class ReservoirSamplingBenchmark {
3636
def timeAlgorithmL(state: BenchmarkState, bh: Blackhole): Unit =
3737
bh.consume(new ReservoirSamplingToListAggregator[Int](state.samples).apply(0 until state.collectionSize))
3838

39+
@Benchmark
40+
def timeAlgorithmLSeq(state: BenchmarkState, bh: Blackhole): Unit =
41+
bh.consume(new ReservoirSamplingToListAggregator[Int](state.samples).apply((0 until state.collectionSize).asInstanceOf[Seq[Int]]))
42+
3943
@Benchmark
4044
def timePriorityQeueue(state: BenchmarkState, bh: Blackhole): Unit =
4145
bh.consume(prioQueueSampler(state.samples).apply(0 until state.collectionSize))

algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import scala.util.Random
1414
* the element type
1515
*/
1616
sealed class Reservoir[T](val capacity: Int) {
17-
var reservoir: mutable.Buffer[T] = mutable.Buffer()
17+
var reservoir: mutable.ArrayBuffer[T] = new mutable.ArrayBuffer
1818

1919
// When the reservoir is full, w is the threshold for accepting an element into the reservoir, and
2020
// the following invariant holds: The maximum score of the elements in the reservoir is w,
@@ -52,6 +52,13 @@ sealed class Reservoir[T](val capacity: Int) {
5252
}
5353
}
5454

55+
// The number of items to skip before accepting the next item is geometrically distributed
56+
// with probability of success w / prior. The prior will be 1 when adding to a single reservoir,
57+
// but when merging reservoirs it will be the threshold of the reservoir being pulled from,
58+
// and in this case we require that w < prior.
59+
private def nextAcceptTime(rng: Random, prior: Double = 1.0): Int =
60+
(-rng.self.nextExponential / Math.log1p(-w / prior)).toInt
61+
5562
/**
5663
* Add multiple elements to the reservoir.
5764
* @param xs
@@ -64,26 +71,55 @@ sealed class Reservoir[T](val capacity: Int) {
6471
* @return
6572
* this reservoir
6673
*/
67-
def append(xs: TraversableOnce[T], rng: Random, prior: Double = 1): Reservoir[T] = {
68-
// The number of items to skip before accepting the next item is geometrically distributed
69-
// with probability of success w / prior. The prior will be 1 when adding to a single reservoir,
70-
// but when merging reservoirs it will be the threshold of the reservoir being pulled from,
71-
// and in this case we require that w < prior.
72-
def nextAcceptTime = (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt
73-
74-
var skip = if (isFull) nextAcceptTime else 0
74+
def append(xs: TraversableOnce[T], rng: Random): Reservoir[T] = {
75+
var skip = if (isFull) nextAcceptTime(rng) else 0
7576
for (x <- xs) {
7677
if (!isFull) {
7778
// keep adding while reservoir is not full
7879
accept(x, rng)
7980
if (isFull) {
80-
skip = nextAcceptTime
81+
skip = nextAcceptTime(rng)
8182
}
8283
} else if (skip > 0) {
8384
skip -= 1
8485
} else {
8586
accept(x, rng)
86-
skip = nextAcceptTime
87+
skip = nextAcceptTime(rng)
88+
}
89+
}
90+
this
91+
}
92+
93+
/**
94+
* Add multiple elements to the reservoir. This overload is optimized for indexed sequences, where we can
95+
* skip over multiple indexes without accessing the elements.
96+
*
97+
* @param xs
98+
* the elements to add
99+
* @param rng
100+
* the random source
101+
* @param prior
102+
* the threshold of the elements being added, such that the added element's value is distributed as
103+
* <pre>U[0, prior]</pre>
104+
* @return
105+
* this reservoir
106+
*/
107+
def append(xs: IndexedSeq[T], rng: Random, prior: Double): Reservoir[T] = {
108+
var i = xs.size.min(capacity - size)
109+
for (j <- 0 until i) {
110+
accept(xs(j), rng)
111+
}
112+
assert(isFull)
113+
114+
val end = xs.size
115+
i -= 1
116+
while (i >= 0 && i < end) {
117+
i += 1 + nextAcceptTime(rng, prior)
118+
// the addition can overflow, in which case i < 0
119+
if (i >= 0 && i < end) {
120+
// element enters the reservoir
121+
reservoir(rng.nextInt(capacity)) = xs(i)
122+
w *= Math.pow(rng.nextDouble, kInv)
87123
}
88124
}
89125
this
@@ -147,7 +183,7 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
147183
s2.reservoir(i) = s2.reservoir.head
148184
s1.append(s2.reservoir.drop(1), rng, s2.w)
149185
} else {
150-
s1.append(s2.reservoir, rng)
186+
s1.append(s2.reservoir, rng, 1.0)
151187
}
152188
}
153189
}
@@ -157,6 +193,10 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
157193
* reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done
158194
* by [[ReservoirSamplingToListAggregator]].
159195
*
196+
* The aggregator defines operations for [[IndexedSeq]]s that allow for more efficient aggregation; however,
197+
* care must be taken with methods such as [[composePrepare()]] which return a regular [[MonoidAggregator]]
198+
* that loses this optimized behavior.
199+
*
160200
* @param k
161201
* the number of elements to sample
162202
* @param randomSupplier
@@ -172,6 +212,7 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
172212
override def prepare(x: T): Reservoir[T] = monoid.build(k, x)
173213

174214
override def apply(xs: TraversableOnce[T]): C = present(agg(xs))
215+
def apply(xs: IndexedSeq[T]): C = present(agg(xs))
175216

176217
override def applyOption(inputs: TraversableOnce[T]): Option[C] =
177218
if (inputs.isEmpty) None else Some(apply(inputs))
@@ -180,11 +221,16 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
180221

181222
override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] =
182223
r.append(xs, randomSupplier())
224+
def appendAll(r: Reservoir[T], xs: IndexedSeq[T]): Reservoir[T] =
225+
r.append(xs, randomSupplier(), 1.0)
183226

184227
override def appendAll(xs: TraversableOnce[T]): Reservoir[T] = agg(xs)
228+
def appendAll(xs: IndexedSeq[T]): Reservoir[T] = agg(xs)
185229

186230
private def agg(xs: TraversableOnce[T]): Reservoir[T] =
187231
appendAll(monoid.zero(k), xs)
232+
private def agg(xs: IndexedSeq[T]): Reservoir[T] =
233+
appendAll(monoid.zero(k), xs)
188234
}
189235

190236
class ReservoirSamplingToListAggregator[T](k: Int)(implicit randomSupplier: () => Random)

algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package com.twitter.algebird.mutable
22

3-
import com.twitter.algebird.{Aggregator, CheckProperties, Preparer}
43
import com.twitter.algebird.RandomSamplingLaws._
4+
import com.twitter.algebird.scalacheck.Distribution.{forAllSampled, uniform}
5+
import com.twitter.algebird.{Aggregator, CheckProperties, Preparer}
6+
import org.scalacheck.Gen
57

68
import scala.util.Random
79

@@ -23,4 +25,11 @@ class ReservoirSamplingTest extends CheckProperties {
2325
property("reservoir sampling with priority queue works") {
2426
randomSamplingDistributions(prioQueueSampler)
2527
}
28+
29+
property("sampling from non-indexed Seq") {
30+
val n = 100
31+
"sampleList" |: forAllSampled(10000, Gen.choose(1, 20))(_ => uniform(n)) { k =>
32+
new ReservoirSamplingToListAggregator[Int](k).apply((0 until n).asInstanceOf[Seq[Int]]).head
33+
}
34+
}
2635
}

0 commit comments

Comments
 (0)