From 23da4cb9f757ebbc5c4846027083e89b1330096d Mon Sep 17 00:00:00 2001 From: lani_karrot Date: Sun, 15 Feb 2026 23:05:39 +0900 Subject: [PATCH] feat: EstimateSubsetSum on varoptitemssketch --- sampling/reservoir_items_sketch.go | 12 +-- sampling/reservoir_items_sketch_test.go | 127 ++++++++++++------------ sampling/varopt_items_sketch.go | 111 ++++++++++++++++++--- sampling/varopt_items_sketch_test.go | 78 +++++++++++++++ 4 files changed, 244 insertions(+), 84 deletions(-) diff --git a/sampling/reservoir_items_sketch.go b/sampling/reservoir_items_sketch.go index 526f322..c7e1905 100644 --- a/sampling/reservoir_items_sketch.go +++ b/sampling/reservoir_items_sketch.go @@ -209,18 +209,18 @@ func (s *ReservoirItemsSketch[T]) EstimateSubsetSum(predicate func(T) bool) (Sam lowerBoundTrueFraction, err := pseudoHypergeometricLowerBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate) if err != nil { - return SampleSubsetSummary{}, nil + return SampleSubsetSummary{}, err } upperBoundTrueFraction, err := pseudoHypergeometricUpperBoundOnP(uint64(numSamples), uint64(trueCount), samplingRate) if err != nil { - return SampleSubsetSummary{}, nil + return SampleSubsetSummary{}, err } estimatedTrueFraction := (1.0 * float64(trueCount)) / float64(numSamples) return SampleSubsetSummary{ - LowerBound: lowerBoundTrueFraction, - Estimate: estimatedTrueFraction, - UpperBound: upperBoundTrueFraction, - TotalSketchWeight: float64(numSamples), + LowerBound: float64(s.n) * lowerBoundTrueFraction, + Estimate: float64(s.n) * estimatedTrueFraction, + UpperBound: float64(s.n) * upperBoundTrueFraction, + TotalSketchWeight: float64(s.n), }, nil } diff --git a/sampling/reservoir_items_sketch_test.go b/sampling/reservoir_items_sketch_test.go index d223377..da571c9 100644 --- a/sampling/reservoir_items_sketch_test.go +++ b/sampling/reservoir_items_sketch_test.go @@ -20,12 +20,12 @@ package sampling import ( "encoding/binary" "math" - "math/rand" "testing" + "github.com/stretchr/testify/assert" 
+ "github.com/apache/datasketches-go/common" "github.com/apache/datasketches-go/internal" - "github.com/stretchr/testify/assert" ) func TestNewReservoirItemsSketch(t *testing.T) { @@ -200,91 +200,86 @@ func TestReservoirItemsSketchResizeFactorSerialization(t *testing.T) { } func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { - t.Run("EmptySketch", func(t *testing.T) { - sketch, err := NewReservoirItemsSketch[int64](10) + var ( + k = 10 + passLB = 0 + passUB = 0 + ) + for trial := 0; trial < 3; trial++ { + sketch, err := NewReservoirItemsSketch[int64](k) assert.NoError(t, err) - summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true }) + // empty sketch + summary, err := sketch.EstimateSubsetSum(func(i int64) bool { + return true + }) assert.NoError(t, err) - assert.Equal(t, 0.0, summary.LowerBound) assert.Equal(t, 0.0, summary.Estimate) - assert.Equal(t, 0.0, summary.UpperBound) assert.Equal(t, 0.0, summary.TotalSketchWeight) - }) - t.Run("ExactMode", func(t *testing.T) { - sketch, err := NewReservoirItemsSketch[int64](10) - assert.NoError(t, err) - for i := int64(1); i <= 5; i++ { - sketch.Update(i) + // exact mode + itemCount := 0.0 + for i := 1; i < k; i++ { + sketch.Update(int64(i)) + itemCount += 1.0 } - summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 }) + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return true + }) assert.NoError(t, err) - assert.Equal(t, 2.0, summary.LowerBound) - assert.Equal(t, 2.0, summary.Estimate) - assert.Equal(t, 2.0, summary.UpperBound) - assert.Equal(t, 5.0, summary.TotalSketchWeight) - }) + assert.Equal(t, itemCount, summary.Estimate) + assert.Equal(t, itemCount, summary.LowerBound) + assert.Equal(t, itemCount, summary.UpperBound) + assert.Equal(t, itemCount, summary.TotalSketchWeight) - t.Run("EstimationModePredicateNeverMatches", func(t *testing.T) { - rand.Seed(7) - sketch, err := NewReservoirItemsSketch[int64](10) - assert.NoError(t, err) - for i := 
int64(1); i <= 100; i++ { - sketch.Update(i) + // estimation mode + for i := k; i < (k + 2); i++ { + sketch.Update(int64(i)) + itemCount += 1.0 } - summary, err := sketch.EstimateSubsetSum(func(int64) bool { return false }) + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return true + }) + assert.NoError(t, err) + assert.Equal(t, itemCount, summary.Estimate) + assert.Equal(t, itemCount, summary.UpperBound) + assert.Less(t, summary.LowerBound, itemCount) + assert.Equal(t, itemCount, summary.TotalSketchWeight) + + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return false + }) assert.NoError(t, err) assert.Equal(t, 0.0, summary.Estimate) assert.Equal(t, 0.0, summary.LowerBound) - assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0) - assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) - }) - - t.Run("EstimationModePredicateAlwaysMatches", func(t *testing.T) { - rand.Seed(11) - sketch, err := NewReservoirItemsSketch[int64](10) - assert.NoError(t, err) - for i := int64(1); i <= 100; i++ { - sketch.Update(i) + assert.Greater(t, summary.UpperBound, 0.0) + assert.Equal(t, itemCount, summary.TotalSketchWeight) + + // finally, a non-degenerate predicate + // insert negative items with identical weights, filter for negative weights only + for i := k; i < (k + 2); i++ { + sketch.Update(int64(-i)) + itemCount += 1.0 } - summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true }) + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return i < 0 + }) assert.NoError(t, err) - assert.Equal(t, 1.0, summary.Estimate) - assert.Equal(t, 1.0, summary.UpperBound) - assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0) - assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) - }) + assert.GreaterOrEqual(t, summary.Estimate, summary.LowerBound) + assert.LessOrEqual(t, summary.Estimate, summary.UpperBound) - t.Run("EstimationModePredicatePartiallyMatches", 
func(t *testing.T) { - rand.Seed(23) - sketch, err := NewReservoirItemsSketch[int64](10) - assert.NoError(t, err) - for i := int64(1); i <= 100; i++ { - sketch.Update(i) + if summary.LowerBound < (itemCount / 1.4) { + passLB++ } - - samples := sketch.Samples() - trueCount := 0 - for _, v := range samples { - if v%2 == 0 { - trueCount++ - } + if summary.UpperBound > (itemCount / 2.6) { + passUB++ } - expectedEstimate := float64(trueCount) / float64(len(samples)) - - summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 }) - assert.NoError(t, err) - assert.InDelta(t, expectedEstimate, summary.Estimate, 0.0) - assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0) - assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0) - assert.True(t, summary.LowerBound <= summary.Estimate) - assert.True(t, summary.Estimate <= summary.UpperBound) - assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) - }) + assert.Equal(t, itemCount, summary.TotalSketchWeight) + } + assert.True(t, passLB >= 2 && passUB >= 2) } func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) { diff --git a/sampling/varopt_items_sketch.go b/sampling/varopt_items_sketch.go index c41664c..55cd16a 100644 --- a/sampling/varopt_items_sketch.go +++ b/sampling/varopt_items_sketch.go @@ -19,10 +19,12 @@ package sampling import ( "errors" + "fmt" "iter" "math" "math/rand" "slices" + "strings" "github.com/apache/datasketches-go/internal" ) @@ -191,12 +193,6 @@ func (s *VarOptItemsSketch[T]) All() iter.Seq[Sample[T]] { } } -// inWarmup returns true if the sketch is still in warmup phase (exact mode). -// During warmup, r=0 and we store all items directly in H. -func (s *VarOptItemsSketch[T]) inWarmup() bool { - return s.r == 0 -} - // peekMin returns the minimum weight in the H region (heap root). 
func (s *VarOptItemsSketch[T]) peekMin() (float64, error) { if s.h == 0 { @@ -223,13 +219,17 @@ func (s *VarOptItemsSketch[T]) update(item T, weight float64, mark bool) error { return s.updateWarmupPhase(item, weight, mark) } - minWeight, err := s.peekMin() - if err != nil { - return err - } + var minWeight float64 + if s.h != 0 { + var err error + minWeight, err = s.peekMin() + if err != nil { + return err + } - if s.h != 0 && minWeight < s.tau() { - return errors.New("sketch not in valid estimation mode") + if minWeight < s.tau() { + return errors.New("sketch not in valid estimation mode") + } } // what tau would be if deletion candidates turn out to be R plus the new item @@ -649,3 +649,90 @@ func (s *VarOptItemsSketch[T]) adjustedSize(maxSize, resizeTarget int) int { } return resizeTarget } + +// String returns a human-readable summary of this sketch. +func (s *VarOptItemsSketch[T]) String() string { + var sb strings.Builder + sb.WriteString("\n") + sb.WriteString("### VarOptItemsSketch SUMMARY: \n") + sb.WriteString(fmt.Sprintf(" k : %d\n", s.k)) + sb.WriteString(fmt.Sprintf(" h : %d\n", s.h)) + sb.WriteString(fmt.Sprintf(" r : %d\n", s.r)) + sb.WriteString(fmt.Sprintf(" weight_r : %g\n", s.totalWeightR)) + sb.WriteString(fmt.Sprintf(" Current size : %d\n", cap(s.data))) + sb.WriteString(fmt.Sprintf(" Resize factor: %v\n", s.rf)) + sb.WriteString("### END SKETCH SUMMARY\n") + return sb.String() +} + +// EstimateSubsetSum computes an estimated subset sum from the entire stream for objects matching a given +// predicate. Provides a lower bound, estimate, and upper bound using a target of 2 standard deviations. +// +// NOTE: This is technically a heuristic method, and tries to err on the conservative side. +// +// predicate: A predicate to use when identifying items. +// Returns a summary object containing the estimate, upper and lower bounds, and the total sketch weight. 
+func (s *VarOptItemsSketch[T]) EstimateSubsetSum(predicate func(T) bool) (SampleSubsetSummary, error) { + if s.n == 0 { + return SampleSubsetSummary{}, nil + } + + var ( + weightSumInH = 0.0 + trueWeightInH = 0.0 + idx = 0 + ) + for idx < s.h { + weight := s.weights[idx] + + weightSumInH += weight + + if predicate(s.data[idx]) { + trueWeightInH += weight + } + + idx++ + } + + // if only heavy items, we have an exact answer + if s.r == 0 { + return SampleSubsetSummary{ + LowerBound: trueWeightInH, + Estimate: trueWeightInH, + UpperBound: trueWeightInH, + TotalSketchWeight: weightSumInH, // total of every item in H, not only the predicate matches + }, nil + } + + numSampled := s.n - int64(s.h) + effectiveSamplingRate := float64(s.r) / float64(numSampled) + if effectiveSamplingRate < 0 || effectiveSamplingRate > 1.0 { + return SampleSubsetSummary{}, errors.New("invalid sampling rate outside [0.0, 1.0]") + } + + trueRCount := 0 + idx++ // skip the gap + for idx < (s.k + 1) { + if predicate(s.data[idx]) { + trueRCount++ + } + + idx++ + } + + lowerBoundTrueFraction, err := pseudoHypergeometricLowerBoundOnP(uint64(s.r), uint64(trueRCount), effectiveSamplingRate) + if err != nil { + return SampleSubsetSummary{}, err + } + upperTrueFraction, err := pseudoHypergeometricUpperBoundOnP(uint64(s.r), uint64(trueRCount), effectiveSamplingRate) + if err != nil { + return SampleSubsetSummary{}, err + } + estimatedTrueFraction := float64(trueRCount) / float64(s.r) + return SampleSubsetSummary{ + LowerBound: trueWeightInH + s.totalWeightR*lowerBoundTrueFraction, + Estimate: trueWeightInH + s.totalWeightR*estimatedTrueFraction, + UpperBound: trueWeightInH + s.totalWeightR*upperTrueFraction, + TotalSketchWeight: weightSumInH + s.totalWeightR, + }, nil +} diff --git a/sampling/varopt_items_sketch_test.go b/sampling/varopt_items_sketch_test.go index 0f9ca04..f0a24a9 100644 --- a/sampling/varopt_items_sketch_test.go +++ b/sampling/varopt_items_sketch_test.go @@ -307,3 +307,81 @@ func TestVarOptItemsSketch_Update(t *testing.T) { assert.InDelta(t, 
weightRatio, 1.0, 1e-10) }) } + +func TestVarOptItemsSketch_EstimateSubsetSum(t *testing.T) { + k := 10 + sketch, err := NewVarOptItemsSketch[int64](uint(k)) + assert.NoError(t, err) + + // empty sketch + summary, err := sketch.EstimateSubsetSum(func(i int64) bool { + return true + }) + assert.NoError(t, err) + assert.Equal(t, 0.0, summary.Estimate) + assert.Equal(t, 0.0, summary.TotalSketchWeight) + + // exact mode + weightSum := 0.0 + for i := 1; i < k; i++ { + err := sketch.Update(int64(i), float64(i)) + assert.NoError(t, err) + + weightSum += float64(i) + } + + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return true + }) + assert.NoError(t, err) + assert.Equal(t, weightSum, summary.Estimate) + assert.Equal(t, weightSum, summary.LowerBound) + assert.Equal(t, weightSum, summary.UpperBound) + assert.Equal(t, weightSum, summary.TotalSketchWeight) + + // estimation mode + for i := k; i < k+2; i++ { + err = sketch.Update(int64(i), float64(i)) + assert.NoError(t, err) + + weightSum += float64(i) + } + + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return true + }) + assert.NoError(t, err) + assert.Equal(t, weightSum, summary.Estimate) + assert.Equal(t, weightSum, summary.UpperBound) + assert.Less(t, summary.LowerBound, weightSum) + assert.Equal(t, weightSum, summary.TotalSketchWeight) + + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return false + }) + assert.NoError(t, err) + assert.Equal(t, 0.0, summary.Estimate) + assert.Equal(t, 0.0, summary.LowerBound) + assert.Greater(t, summary.UpperBound, 0.0) + assert.Equal(t, weightSum, summary.TotalSketchWeight) + + // finally, a non-degenerate predicate + // insert negative items with identical weights, filter for negative weights only + for i := 1; i < k+2; i++ { + err := sketch.Update(int64(-i), float64(i)) + assert.NoError(t, err) + + weightSum += float64(i) + } + + summary, err = sketch.EstimateSubsetSum(func(i int64) bool { + return i < 0 + }) + assert.NoError(t, 
err) + assert.GreaterOrEqual(t, summary.Estimate, summary.LowerBound) + assert.LessOrEqual(t, summary.Estimate, summary.UpperBound) + + assert.Less(t, summary.LowerBound, weightSum/1.4) + assert.Greater(t, summary.UpperBound, weightSum/2.6) + assert.Equal(t, weightSum, summary.TotalSketchWeight) +}