Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions sampling/reservoir_items_sketch.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,18 @@ func (s *ReservoirItemsSketch[T]) EstimateSubsetSum(predicate func(T) bool) (Sam

lowerBoundTrueFraction, err := pseudoHypergeometricLowerBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate)
if err != nil {
return SampleSubsetSummary{}, nil
return SampleSubsetSummary{}, err
}
upperBoundTrueFraction, err := pseudoHypergeometricUpperBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate)
if err != nil {
return SampleSubsetSummary{}, nil
return SampleSubsetSummary{}, err
}
estimatedTrueFraction := (1.0 * float64(trueCount)) / float64(numSamples)
return SampleSubsetSummary{
LowerBound: lowerBoundTrueFraction,
Estimate: estimatedTrueFraction,
UpperBound: upperBoundTrueFraction,
TotalSketchWeight: float64(numSamples),
LowerBound: float64(s.n) * lowerBoundTrueFraction,
Estimate: float64(s.n) * estimatedTrueFraction,
UpperBound: float64(s.n) * upperBoundTrueFraction,
TotalSketchWeight: float64(s.n),
}, nil
}

Expand Down
127 changes: 61 additions & 66 deletions sampling/reservoir_items_sketch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ package sampling
import (
"encoding/binary"
"math"
"math/rand"
"testing"

"github.com/stretchr/testify/assert"

"github.com/apache/datasketches-go/common"
"github.com/apache/datasketches-go/internal"
"github.com/stretchr/testify/assert"
)

func TestNewReservoirItemsSketch(t *testing.T) {
Expand Down Expand Up @@ -200,91 +200,86 @@ func TestReservoirItemsSketchResizeFactorSerialization(t *testing.T) {
}

func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) {
t.Run("EmptySketch", func(t *testing.T) {
sketch, err := NewReservoirItemsSketch[int64](10)
var (
k = 10
passLB = 0
passUB = 0
)
for trial := 0; trial < 3; trial++ {
sketch, err := NewReservoirItemsSketch[int64](k)
assert.NoError(t, err)

summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true })
// empty sketch
summary, err := sketch.EstimateSubsetSum(func(i int64) bool {
return true
})
assert.NoError(t, err)
assert.Equal(t, 0.0, summary.LowerBound)
assert.Equal(t, 0.0, summary.Estimate)
assert.Equal(t, 0.0, summary.UpperBound)
assert.Equal(t, 0.0, summary.TotalSketchWeight)
})

t.Run("ExactMode", func(t *testing.T) {
sketch, err := NewReservoirItemsSketch[int64](10)
assert.NoError(t, err)
for i := int64(1); i <= 5; i++ {
sketch.Update(i)
// exact mode
itemCount := 0.0
for i := 1; i < k; i++ {
sketch.Update(int64(i))
itemCount += 1.0
}

summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 })
summary, err = sketch.EstimateSubsetSum(func(i int64) bool {
return true
})
assert.NoError(t, err)
assert.Equal(t, 2.0, summary.LowerBound)
assert.Equal(t, 2.0, summary.Estimate)
assert.Equal(t, 2.0, summary.UpperBound)
assert.Equal(t, 5.0, summary.TotalSketchWeight)
})
assert.Equal(t, itemCount, summary.Estimate)
assert.Equal(t, itemCount, summary.LowerBound)
assert.Equal(t, itemCount, summary.UpperBound)
assert.Equal(t, itemCount, summary.TotalSketchWeight)

t.Run("EstimationModePredicateNeverMatches", func(t *testing.T) {
rand.Seed(7)
sketch, err := NewReservoirItemsSketch[int64](10)
assert.NoError(t, err)
for i := int64(1); i <= 100; i++ {
sketch.Update(i)
// estimation mode
for i := k; i < (k + 2); i++ {
sketch.Update(int64(i))
itemCount += 1.0
}

summary, err := sketch.EstimateSubsetSum(func(int64) bool { return false })
summary, err = sketch.EstimateSubsetSum(func(i int64) bool {
return true
})
assert.NoError(t, err)
assert.Equal(t, itemCount, summary.Estimate)
assert.Equal(t, itemCount, summary.UpperBound)
assert.Less(t, summary.LowerBound, itemCount)
assert.Equal(t, itemCount, summary.TotalSketchWeight)

summary, err = sketch.EstimateSubsetSum(func(i int64) bool {
return false
})
assert.NoError(t, err)
assert.Equal(t, 0.0, summary.Estimate)
assert.Equal(t, 0.0, summary.LowerBound)
assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0)
assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight)
})

t.Run("EstimationModePredicateAlwaysMatches", func(t *testing.T) {
rand.Seed(11)
sketch, err := NewReservoirItemsSketch[int64](10)
assert.NoError(t, err)
for i := int64(1); i <= 100; i++ {
sketch.Update(i)
assert.Greater(t, summary.UpperBound, 0.0)
assert.Equal(t, itemCount, summary.TotalSketchWeight)

// finally, a non-degenerate predicate
// insert negative items with identical weights, filter for negative weights only
for i := k; i < (k + 2); i++ {
sketch.Update(int64(-i))
itemCount += 1.0
}

summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true })
summary, err = sketch.EstimateSubsetSum(func(i int64) bool {
return i < 0
})
assert.NoError(t, err)
assert.Equal(t, 1.0, summary.Estimate)
assert.Equal(t, 1.0, summary.UpperBound)
assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0)
assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight)
})
assert.GreaterOrEqual(t, summary.Estimate, summary.LowerBound)
assert.LessOrEqual(t, summary.Estimate, summary.UpperBound)

t.Run("EstimationModePredicatePartiallyMatches", func(t *testing.T) {
rand.Seed(23)
sketch, err := NewReservoirItemsSketch[int64](10)
assert.NoError(t, err)
for i := int64(1); i <= 100; i++ {
sketch.Update(i)
if summary.LowerBound < (itemCount / 1.4) {
passLB++
}

samples := sketch.Samples()
trueCount := 0
for _, v := range samples {
if v%2 == 0 {
trueCount++
}
if summary.UpperBound > (itemCount / 2.6) {
passUB++
}
expectedEstimate := float64(trueCount) / float64(len(samples))

summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 })
assert.NoError(t, err)
assert.InDelta(t, expectedEstimate, summary.Estimate, 0.0)
assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0)
assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0)
assert.True(t, summary.LowerBound <= summary.Estimate)
assert.True(t, summary.Estimate <= summary.UpperBound)
assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight)
})
assert.Equal(t, itemCount, summary.TotalSketchWeight)
}
assert.True(t, passLB >= 2 && passUB >= 2)
}

func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) {
Expand Down
111 changes: 99 additions & 12 deletions sampling/varopt_items_sketch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package sampling

import (
"errors"
"fmt"
"iter"
"math"
"math/rand"
"slices"
"strings"

"github.com/apache/datasketches-go/internal"
)
Expand Down Expand Up @@ -191,12 +193,6 @@ func (s *VarOptItemsSketch[T]) All() iter.Seq[Sample[T]] {
}
}

// inWarmup reports whether the sketch is still in its warmup phase (exact mode).
// The check is simply r == 0; per the original note, during warmup all items
// are stored directly in H.
func (s *VarOptItemsSketch[T]) inWarmup() bool {
	return s.r == 0
}

// peekMin returns the minimum weight in the H region (heap root).
func (s *VarOptItemsSketch[T]) peekMin() (float64, error) {
if s.h == 0 {
Expand All @@ -223,13 +219,17 @@ func (s *VarOptItemsSketch[T]) update(item T, weight float64, mark bool) error {
return s.updateWarmupPhase(item, weight, mark)
}

minWeight, err := s.peekMin()
if err != nil {
return err
}
var minWeight float64
if s.h != 0 {
var err error
minWeight, err = s.peekMin()
if err != nil {
return err
}

if s.h != 0 && minWeight < s.tau() {
return errors.New("sketch not in valid estimation mode")
if minWeight < s.tau() {
return errors.New("sketch not in valid estimation mode")
}
}

// what tau would be if deletion candidates turn out to be R plus the new item
Expand Down Expand Up @@ -649,3 +649,90 @@ func (s *VarOptItemsSketch[T]) adjustedSize(maxSize, resizeTarget int) int {
}
return resizeTarget
}

// String renders a short multi-line, human-readable summary of this sketch.
// It reports k, h, r, totalWeightR, the capacity of the data slice and the
// resize factor, framed by a header and a footer line.
func (s *VarOptItemsSketch[T]) String() string {
	var out strings.Builder

	out.WriteString("\n")
	out.WriteString("### VarOptItemsSketch SUMMARY: \n")

	// Format straight into the builder rather than via intermediate strings.
	fmt.Fprintf(&out, " k : %d\n", s.k)
	fmt.Fprintf(&out, " h : %d\n", s.h)
	fmt.Fprintf(&out, " r : %d\n", s.r)
	fmt.Fprintf(&out, " weight_r : %g\n", s.totalWeightR)
	fmt.Fprintf(&out, " Current size : %d\n", cap(s.data))
	fmt.Fprintf(&out, " Resize factor: %v\n", s.rf)

	out.WriteString("### END SKETCH SUMMARY\n")
	return out.String()
}

// EstimateSubsetSum computes an estimated subset sum from the entire stream for objects matching a given
// predicate. Provides a lower bound, estimate, and upper bound using a target of 2 standard deviations.
//
// NOTE: This is technically a heuristic method, and tries to err on the conservative side.
//
// predicate: A predicate to use when identifying items.
//
// Returns a summary object containing the estimate, upper and lower bounds, and the total sketch weight.
// An empty sketch (n == 0) yields a zero-valued summary. When the sketch holds only H-region items
// (r == 0) the bounds and the estimate are all equal to the matching weight in H, and the total sketch
// weight is the sum of all weights in H regardless of the predicate.
//
// An error is returned if the effective sampling rate falls outside [0.0, 1.0] or if either of the
// pseudo-hypergeometric bound computations fails.
func (s *VarOptItemsSketch[T]) EstimateSubsetSum(predicate func(T) bool) (SampleSubsetSummary, error) {
	if s.n == 0 {
		return SampleSubsetSummary{}, nil
	}

	// Walk the H region, accumulating both the total weight and the weight of
	// the items accepted by the predicate.
	var (
		weightSumInH  = 0.0
		trueWeightInH = 0.0
		idx           = 0
	)
	for idx < s.h {
		weight := s.weights[idx]

		weightSumInH += weight

		if predicate(s.data[idx]) {
			trueWeightInH += weight
		}

		idx++
	}

	// if only heavy items, we have an exact answer
	if s.r == 0 {
		return SampleSubsetSummary{
			LowerBound: trueWeightInH,
			Estimate:   trueWeightInH,
			UpperBound: trueWeightInH,
			// The total covers every item in H, not only those matching the
			// predicate; this mirrors the estimation-mode return below, where
			// totalWeightR would contribute nothing.
			TotalSketchWeight: weightSumInH,
		}, nil
	}

	numSampled := s.n - int64(s.h)
	effectiveSamplingRate := float64(s.r) / float64(numSampled)
	if effectiveSamplingRate < 0 || effectiveSamplingRate > 1.0 {
		return SampleSubsetSummary{}, errors.New("invalid sampling rate outside [0.0, 1.0]")
	}

	// Count the predicate matches among the remaining slots up to index k.
	trueRCount := 0
	idx++ // skip the gap
	for idx < (s.k + 1) {
		if predicate(s.data[idx]) {
			trueRCount++
		}

		idx++
	}

	lowerBoundTrueFraction, err := pseudoHypergeometricLowerBoundOnP(uint64(s.r), uint64(trueRCount), effectiveSamplingRate)
	if err != nil {
		return SampleSubsetSummary{}, err
	}
	upperTrueFraction, err := pseudoHypergeometricUpperBoundOnP(uint64(s.r), uint64(trueRCount), effectiveSamplingRate)
	if err != nil {
		return SampleSubsetSummary{}, err
	}
	estimatedTrueFraction := float64(trueRCount) / float64(s.r)

	// The matching weight in H is counted as-is; totalWeightR is scaled by the
	// (bounded) fraction of the counted slots that match.
	return SampleSubsetSummary{
		LowerBound:        trueWeightInH + s.totalWeightR*lowerBoundTrueFraction,
		Estimate:          trueWeightInH + s.totalWeightR*estimatedTrueFraction,
		UpperBound:        trueWeightInH + s.totalWeightR*upperTrueFraction,
		TotalSketchWeight: weightSumInH + s.totalWeightR,
	}, nil
}
Loading
Loading