@@ -14,7 +14,7 @@ import scala.util.Random
14
14
* the element type
15
15
*/
16
16
sealed class Reservoir [T ](val capacity : Int ) {
17
- var reservoir : mutable.Buffer [T ] = mutable.Buffer ()
17
+ var reservoir : mutable.ArrayBuffer [T ] = new mutable.ArrayBuffer
18
18
19
19
// When the reservoir is full, w is the threshold for accepting an element into the reservoir, and
20
20
// the following invariant holds: The maximum score of the elements in the reservoir is w,
@@ -52,6 +52,13 @@ sealed class Reservoir[T](val capacity: Int) {
52
52
}
53
53
}
54
54
55
// Draws how many candidate elements to pass over before the next one is accepted. The skip
// count is geometrically distributed with success probability w / prior. For a plain append
// the prior is 1; when merging reservoirs it is the threshold of the reservoir being pulled
// from, and in that case w < prior is required.
private def nextAcceptTime(rng: Random, prior: Double = 1.0): Int = {
  val exponentialDraw = rng.self.nextExponential
  val logRejectProbability = Math.log1p(-w / prior)
  (-exponentialDraw / logRejectProbability).toInt
}
61
+
55
62
/**
56
63
* Add multiple elements to the reservoir.
57
64
* @param xs
@@ -64,26 +71,55 @@ sealed class Reservoir[T](val capacity: Int) {
64
71
* @return
65
72
* this reservoir
66
73
*/
67
def append(xs: TraversableOnce[T], rng: Random): Reservoir[T] = {
  // Number of elements still to be passed over before the next acceptance; it is only
  // consulted while the reservoir is full.
  var remaining = if (isFull) nextAcceptTime(rng) else 0
  xs.foreach { elem =>
    if (isFull && remaining > 0) {
      remaining -= 1
    } else {
      // Either there is still a free slot, or the current skip run has ended.
      val wasFull = isFull
      accept(elem, rng)
      // A fresh skip count is drawn after every acceptance made while full, and also at the
      // moment the reservoir first becomes full.
      if (wasFull || isFull) {
        remaining = nextAcceptTime(rng)
      }
    }
  }
  this
}
92
+
93
/**
 * Add multiple elements to the reservoir. This overload is optimized for indexed sequences, where we can
 * skip over multiple indexes without accessing the elements.
 *
 * Elements are first accepted unconditionally until the reservoir is full. If `xs` is exhausted before
 * that happens (including when `xs` is empty), the method returns without sampling. Once the reservoir is
 * full, the remaining indexes are visited by skip-sampling, which also covers the case where the reservoir
 * was already full before the call.
 *
 * @param xs
 *   the elements to add
 * @param rng
 *   the random source
 * @param prior
 *   the threshold of the elements being added, such that the added element's value is distributed as
 *   <pre>U[0, prior]</pre>
 * @return
 *   this reservoir
 */
def append(xs: IndexedSeq[T], rng: Random, prior: Double): Reservoir[T] = {
  val end = xs.size

  // Fill any free slots first; these elements always enter the reservoir.
  val filled = end.min(capacity - size)
  for (j <- 0 until filled) {
    accept(xs(j), rng)
  }

  // Skip-sampling is only defined for a full reservoir. When xs was too short to fill it, every
  // element has already been accepted above and there is nothing left to do.
  if (isFull) {
    // Index of the next element to accept. A Long is used so that adding a skip count as large as
    // Int.MaxValue cannot overflow; a value >= end simply terminates the loop.
    var next = filled.toLong + nextAcceptTime(rng, prior)
    while (next < end) {
      // element enters the reservoir, evicting a uniformly chosen resident
      reservoir(rng.nextInt(capacity)) = xs(next.toInt)
      w *= Math.pow(rng.nextDouble, kInv)
      next += 1L + nextAcceptTime(rng, prior)
    }
  }
  this
}
@@ -147,7 +183,7 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
147
183
s2.reservoir(i) = s2.reservoir.head
148
184
s1.append(s2.reservoir.drop(1 ), rng, s2.w)
149
185
} else {
150
- s1.append(s2.reservoir, rng)
186
+ s1.append(s2.reservoir, rng, 1.0 )
151
187
}
152
188
}
153
189
}
@@ -157,6 +193,10 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
157
193
* reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done
158
194
* by [[ReservoirSamplingToListAggregator ]].
159
195
*
196
+ * The aggregator defines overloads for [[IndexedSeq]]s that allow for more efficient aggregation. However,
197
+ * care must be taken with methods such as [[composePrepare()]], which return a regular [[MonoidAggregator]]
198
+ * that loses this optimized behavior.
199
+ *
160
200
* @param k
161
201
* the number of elements to sample
162
202
* @param randomSupplier
@@ -172,6 +212,7 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
172
212
/** Builds the reservoir for a single element by delegating to `monoid.build(k, x)`. */
override def prepare(x: T): Reservoir[T] = {
  monoid.build(k, x)
}
173
213
174
214
/** Aggregates `xs` into a reservoir and converts it with `present`. */
override def apply(xs: TraversableOnce[T]): C = {
  val sampled = agg(xs)
  present(sampled)
}
215
/**
 * Aggregates an indexed sequence into a reservoir and converts it with `present`. This overload is
 * selected statically when the argument's type is an [[IndexedSeq]].
 */
def apply(xs: IndexedSeq[T]): C = {
  val sampled = agg(xs)
  present(sampled)
}
175
216
176
217
/** Returns `None` for empty input, otherwise the result of `apply` wrapped in `Some`. */
override def applyOption(inputs: TraversableOnce[T]): Option[C] = {
  if (inputs.nonEmpty) {
    Some(apply(inputs))
  } else {
    None
  }
}
@@ -180,11 +221,16 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
180
221
181
222
/** Appends every element of `xs` to `r`, using a random source obtained from `randomSupplier`. */
override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] = {
  val rng = randomSupplier()
  r.append(xs, rng)
}
224
/**
 * Appends every element of the indexed sequence `xs` to `r` through the indexed `append` overload,
 * passing a prior of 1.0 and a random source obtained from `randomSupplier`.
 */
def appendAll(r: Reservoir[T], xs: IndexedSeq[T]): Reservoir[T] = {
  val rng = randomSupplier()
  r.append(xs, rng, 1.0)
}
183
226
184
227
/** Appends every element of `xs` to the reservoir returned by `monoid.zero(k)`. */
override def appendAll(xs: TraversableOnce[T]): Reservoir[T] = {
  appendAll(monoid.zero(k), xs)
}
228
/** Appends every element of the indexed sequence `xs` to the reservoir returned by `monoid.zero(k)`. */
def appendAll(xs: IndexedSeq[T]): Reservoir[T] = {
  appendAll(monoid.zero(k), xs)
}
185
229
186
230
// Shared helper: starts from the reservoir returned by monoid.zero(k) and appends all of xs to it.
private def agg(xs: TraversableOnce[T]): Reservoir[T] = {
  val empty = monoid.zero(k)
  appendAll(empty, xs)
}
232
// Indexed-sequence counterpart of agg: the static IndexedSeq type routes the call to the
// appendAll overload that takes an IndexedSeq.
private def agg(xs: IndexedSeq[T]): Reservoir[T] = {
  val empty = monoid.zero(k)
  appendAll(empty, xs)
}
188
234
}
189
235
190
236
class ReservoirSamplingToListAggregator [T ](k : Int )(implicit randomSupplier : () => Random )
0 commit comments