From 75c9febbc13d8fb272ee900199b8b7052a5a6919 Mon Sep 17 00:00:00 2001 From: Fengzdadi <453788063@qq.com> Date: Sun, 8 Feb 2026 14:41:51 -0500 Subject: [PATCH 1/2] sampling: align resize factor behavior with Java semantics --- sampling/reservoir_items_sketch.go | 79 ++++++++++++++++---- sampling/reservoir_items_sketch_test.go | 99 ++++++++++++++++++++++--- sampling/varopt_items_sketch.go | 10 ++- sampling/varopt_items_sketch_test.go | 5 ++ 4 files changed, 163 insertions(+), 30 deletions(-) diff --git a/sampling/reservoir_items_sketch.go b/sampling/reservoir_items_sketch.go index a88a14b..d7e84ca 100644 --- a/sampling/reservoir_items_sketch.go +++ b/sampling/reservoir_items_sketch.go @@ -21,9 +21,7 @@ import ( "encoding/binary" "errors" "fmt" - "math" "math/rand" - "slices" "strings" "github.com/apache/datasketches-go/common" @@ -44,11 +42,35 @@ const ( defaultResizeFactor = ResizeX8 minK = 2 + maxItemsSeen = int64(0xFFFFFFFFFFFF) // smallest sampling array allocated: 16 minLgArrItems = 4 ) +func resizeFactorLg(rf ResizeFactor) (int, error) { + switch rf { + case ResizeX1: + return 0, nil + case ResizeX2: + return 1, nil + case ResizeX4: + return 2, nil + case ResizeX8: + return 3, nil + default: + return 0, errors.New("unsupported resize factor") + } +} + +func mustResizeFactorLg(rf ResizeFactor) int { + lgRf, err := resizeFactorLg(rf) + if err != nil { + panic(err) + } + return lgRf +} + // ReservoirItemsSketch provides a uniform random sample of items // from a stream of unknown size using the reservoir sampling algorithm. // @@ -93,9 +115,14 @@ func NewReservoirItemsSketch[T any]( opt(options) } + lgRf, err := resizeFactorLg(options.resizeFactor) + if err != nil { + return nil, err + } + ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k)) initialLgSize := startingSubMultiple( - ceilingLgK, int(math.Log2(float64(options.resizeFactor))), minLgArrItems, + ceilingLgK, lgRf, minLgArrItems, ) return &ReservoirItemsSketch[T]{ k: k, @@ -107,6 +134,10 @@ func NewReservoirItemsSketch[T any]( // Update adds an item to the sketch using reservoir sampling algorithm. func (s *ReservoirItemsSketch[T]) Update(item T) { + if s.n == maxItemsSeen { + panic(fmt.Sprintf("sketch has exceeded capacity for total items seen: %d", maxItemsSeen)) + } + if s.n < int64(s.k) { // Initial phase: store all items until reservoir is full if s.n >= int64(cap(s.data)) { @@ -114,19 +145,25 @@ func (s *ReservoirItemsSketch[T]) Update(item T) { } s.data = append(s.data, item) + s.n++ } else { // Steady state: replace with probability k/n - j := rand.Int63n(s.n + 1) - if j < int64(s.k) { - s.data[j] = item + s.n++ + if rand.Float64()*float64(s.n) < float64(s.k) { + s.data[rand.Intn(s.k)] = item } } - s.n++ } func (s *ReservoirItemsSketch[T]) growReservoir() { - adjustedSize := adjustedSamplingAllocationSize(s.k, cap(s.data)< maxItemsSeen { + panic(fmt.Sprintf( + "sketch has exceeded capacity for total items seen. limit: %d, found: %d", + maxItemsSeen, s.n, + )) + } } // Serialization constants @@ -417,6 +461,9 @@ func NewReservoirItemsSketchFromSlice[T any](data []byte, serde ItemsSerDe[T]) ( } n := int64(binary.LittleEndian.Uint64(data[8:])) + if n > maxItemsSeen { + return nil, fmt.Errorf("items seen exceeds limit: %d", maxItemsSeen) + } numSamples := int(min(n, int64(k))) itemsData := data[preambleBytes:] diff --git a/sampling/reservoir_items_sketch_test.go b/sampling/reservoir_items_sketch_test.go index d223377..c1ef2d9 100644 --- a/sampling/reservoir_items_sketch_test.go +++ b/sampling/reservoir_items_sketch_test.go @@ -19,7 +19,6 @@ package sampling import ( "encoding/binary" - "math" "math/rand" "testing" @@ -78,6 +77,9 @@ func TestReservoirItemsSketchInvalidK(t *testing.T) { _, err = NewReservoirItemsSketch[int64](1) assert.ErrorContains(t, err, "k must be at least 2") + + _, err = NewReservoirItemsSketch[int64](16, WithReservoirItemsSketchResizeFactor(ResizeFactor(3))) + assert.ErrorContains(t, err, "unsupported resize factor") } func TestReservoirItemsSketch_Update(t *testing.T) { @@ -150,8 +152,9 @@ func TestReservoirItemsSketchReset(t *testing.T) { assert.NoError(t, err) ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k)) + lgRf, _ := resizeFactorLg(defaultResizeFactor) initialLgSize := startingSubMultiple( - ceilingLgK, int(math.Log2(float64(defaultResizeFactor))), minLgArrItems, + ceilingLgK, lgRf, minLgArrItems, ) expectedInitialCap := adjustedSamplingAllocationSize(k, 1<= 0.0 && summary.UpperBound <= 1.0) - assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) + assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= float64(sketch.N())) + assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight) }) t.Run("EstimationModePredicateAlwaysMatches", func(t *testing.T) { @@ -253,10 +295,10 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { summary, err := sketch.EstimateSubsetSum(func(int64) bool { return true }) assert.NoError(t, err) - assert.Equal(t, 1.0, summary.Estimate) - assert.Equal(t, 1.0, summary.UpperBound) - assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0) - assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) + assert.Equal(t, float64(sketch.N()), summary.Estimate) + assert.Equal(t, float64(sketch.N()), summary.UpperBound) + assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= float64(sketch.N())) + assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight) }) t.Run("EstimationModePredicatePartiallyMatches", func(t *testing.T) { @@ -274,16 +316,16 @@ func TestReservoirItemsSketchEstimateSubsetSum(t *testing.T) { trueCount++ } } - expectedEstimate := float64(trueCount) / float64(len(samples)) + expectedEstimate := float64(sketch.N()) * (float64(trueCount) / float64(len(samples))) summary, err := sketch.EstimateSubsetSum(func(v int64) bool { return v%2 == 0 }) assert.NoError(t, err) assert.InDelta(t, expectedEstimate, summary.Estimate, 0.0) - assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= 1.0) - assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= 1.0) + assert.True(t, summary.LowerBound >= 0.0 && summary.LowerBound <= float64(sketch.N())) + assert.True(t, summary.UpperBound >= 0.0 && summary.UpperBound <= float64(sketch.N())) assert.True(t, summary.LowerBound <= summary.Estimate) assert.True(t, summary.Estimate <= summary.UpperBound) - assert.Equal(t, float64(sketch.NumSamples()), summary.TotalSketchWeight) + assert.Equal(t, float64(sketch.N()), summary.TotalSketchWeight) }) } @@ -301,3 +343,36 @@ func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) { assert.Equal(t, 1024, sketch.K()) assert.Equal(t, ResizeX8, sketch.rf) } + +func TestReservoirItemsSketchUpdatePanicsAtMaxItemsSeen(t *testing.T) { + sketch, err := NewReservoirItemsSketch[int64](8) + assert.NoError(t, err) + sketch.n = maxItemsSeen + + assert.Panics(t, func() { + sketch.Update(1) + }) +} + +func TestReservoirItemsSketchForceIncrementItemsSeenPanicsOnOverflow(t *testing.T) { + sketch, err := NewReservoirItemsSketch[int64](8) + assert.NoError(t, err) + sketch.n = maxItemsSeen - 1 + + assert.Panics(t, func() { + sketch.forceIncrementItemsSeen(2) + }) +} + +func TestReservoirItemsSketchFromSliceRejectsNTooLarge(t *testing.T) { + data := make([]byte, 16) + data[0] = 0xC0 | preambleIntsNonEmpty + data[1] = serVer + data[2] = byte(internal.FamilyEnum.ReservoirItems.Id) + data[3] = 0 + binary.LittleEndian.PutUint32(data[4:], uint32(8)) + binary.LittleEndian.PutUint64(data[8:], uint64(maxItemsSeen+1)) + + _, err := NewReservoirItemsSketchFromSlice[int64](data, Int64SerDe{}) + assert.ErrorContains(t, err, "items seen exceeds limit") +} diff --git a/sampling/varopt_items_sketch.go b/sampling/varopt_items_sketch.go index 2b47028..4461a9f 100644 --- a/sampling/varopt_items_sketch.go +++ b/sampling/varopt_items_sketch.go @@ -91,7 +91,12 @@ func NewVarOptItemsSketch[T any](k int, opts ...VarOptOption) (*VarOptItemsSketc opt(cfg) } - initialSize := int(cfg.resizeFactor) + lgRf, err := resizeFactorLg(cfg.resizeFactor) + if err != nil { + return nil, err + } + + initialSize := 1 << lgRf if initialSize > k { initialSize = k } @@ -498,8 +503,9 @@ func (s *VarOptItemsSketch[T]) swap(i, j int) { // growDataArrays increases the capacity of data and weights arrays. func (s *VarOptItemsSketch[T]) growDataArrays() { + lgRf := mustResizeFactorLg(s.rf) prevSize := s.allocatedSize - newSize := s.adjustedSize(s.k, prevSize< varOptMaxK") } + + _, err = NewVarOptItemsSketch[string](16, WithResizeFactor(ResizeFactor(3))) + if err == nil { + t.Error("expected error for unsupported resize factor") + } } func TestVarOptItemsSketch_WarmupPhase(t *testing.T) { From 2a60aaf58f681bfe54ce395913fe21489b478ecd Mon Sep 17 00:00:00 2001 From: Fengzdadi <453788063@qq.com> Date: Sun, 8 Feb 2026 15:14:38 -0500 Subject: [PATCH 2/2] sampling: return errors instead of panicking in reservoir updates --- sampling/reservoir_items_sketch.go | 44 ++++++++++++++----------- sampling/reservoir_items_sketch_test.go | 24 +++++++------- sampling/reservoir_items_union.go | 42 ++++++++++++++--------- sampling/varopt_items_sketch.go | 2 +- 4 files changed, 62 insertions(+), 50 deletions(-) diff --git a/sampling/reservoir_items_sketch.go b/sampling/reservoir_items_sketch.go index d7e84ca..02f48d2 100644 --- a/sampling/reservoir_items_sketch.go +++ b/sampling/reservoir_items_sketch.go @@ -63,14 +63,6 @@ func resizeFactorLg(rf ResizeFactor) (int, error) { } } -func mustResizeFactorLg(rf ResizeFactor) int { - lgRf, err := resizeFactorLg(rf) - if err != nil { - panic(err) - } - return lgRf -} - // ReservoirItemsSketch provides a uniform random sample of items // from a stream of unknown size using the reservoir sampling algorithm. // @@ -133,15 +125,17 @@ func NewReservoirItemsSketch[T any]( } // Update adds an item to the sketch using reservoir sampling algorithm. -func (s *ReservoirItemsSketch[T]) Update(item T) { +func (s *ReservoirItemsSketch[T]) Update(item T) error { if s.n == maxItemsSeen { - panic(fmt.Sprintf("sketch has exceeded capacity for total items seen: %d", maxItemsSeen)) + return fmt.Errorf("sketch has exceeded capacity for total items seen: %d", maxItemsSeen) } if s.n < int64(s.k) { // Initial phase: store all items until reservoir is full if s.n >= int64(cap(s.data)) { - s.growReservoir() + if err := s.growReservoir(); err != nil { + return err + } } s.data = append(s.data, item) @@ -153,17 +147,22 @@ func (s *ReservoirItemsSketch[T]) Update(item T) { s.data[rand.Intn(s.k)] = item } } + return nil } -func (s *ReservoirItemsSketch[T]) growReservoir() { - lgRf := mustResizeFactorLg(s.rf) +func (s *ReservoirItemsSketch[T]) growReservoir() error { + lgRf, err := resizeFactorLg(s.rf) + if err != nil { + return err + } targetCap := adjustedSamplingAllocationSize(s.k, cap(s.data)< maxItemsSeen { - panic(fmt.Sprintf( + return fmt.Errorf( "sketch has exceeded capacity for total items seen. limit: %d, found: %d", maxItemsSeen, s.n, - )) + ) } + return nil } // Serialization constants diff --git a/sampling/reservoir_items_sketch_test.go b/sampling/reservoir_items_sketch_test.go index c1ef2d9..554cbe4 100644 --- a/sampling/reservoir_items_sketch_test.go +++ b/sampling/reservoir_items_sketch_test.go @@ -41,8 +41,8 @@ func TestReservoirItemsSketchWithStrings(t *testing.T) { assert.NoError(t, err) sketch.Update("apple") - sketch.Update("banana") - sketch.Update("cherry") + _ = sketch.Update("banana") + _ = sketch.Update("cherry") assert.Equal(t, int64(3), sketch.N()) assert.Equal(t, 3, sketch.NumSamples()) @@ -62,9 +62,9 @@ func TestReservoirItemsSketchWithStruct(t *testing.T) { sketch, err := NewReservoirItemsSketch[Event](5) assert.NoError(t, err) - sketch.Update(Event{1, "login"}) - sketch.Update(Event{2, "logout"}) - sketch.Update(Event{3, "click"}) + _ = sketch.Update(Event{1, "login"}) + _ = sketch.Update(Event{2, "logout"}) + _ = sketch.Update(Event{3, "click"}) assert.Equal(t, int64(3), sketch.N()) samples := sketch.Samples() @@ -344,24 +344,22 @@ func TestReservoirItemsSketchLegacySerVerEmpty(t *testing.T) { assert.Equal(t, ResizeX8, sketch.rf) } -func TestReservoirItemsSketchUpdatePanicsAtMaxItemsSeen(t *testing.T) { +func TestReservoirItemsSketchUpdateReturnsErrorAtMaxItemsSeen(t *testing.T) { sketch, err := NewReservoirItemsSketch[int64](8) assert.NoError(t, err) sketch.n = maxItemsSeen - assert.Panics(t, func() { - sketch.Update(1) - }) + err = sketch.Update(1) + assert.ErrorContains(t, err, "sketch has exceeded capacity") } -func TestReservoirItemsSketchForceIncrementItemsSeenPanicsOnOverflow(t *testing.T) { +func TestReservoirItemsSketchForceIncrementItemsSeenReturnsErrorOnOverflow(t *testing.T) { sketch, err := NewReservoirItemsSketch[int64](8) assert.NoError(t, err) sketch.n = maxItemsSeen - 1 - assert.Panics(t, func() { - sketch.forceIncrementItemsSeen(2) - }) + err = sketch.forceIncrementItemsSeen(2) + assert.ErrorContains(t, err, "sketch has exceeded capacity") } func TestReservoirItemsSketchFromSliceRejectsNTooLarge(t *testing.T) { diff --git a/sampling/reservoir_items_union.go b/sampling/reservoir_items_union.go index f626602..08904c5 100644 --- a/sampling/reservoir_items_union.go +++ b/sampling/reservoir_items_union.go @@ -53,11 +53,15 @@ func NewReservoirItemsUnion[T any](maxK int) (*ReservoirItemsUnion[T], error) { } // Update adds a single item to the union. -func (u *ReservoirItemsUnion[T]) Update(item T) { +func (u *ReservoirItemsUnion[T]) Update(item T) error { if u.gadget == nil { - u.gadget, _ = NewReservoirItemsSketch[T](u.maxK) + var err error + u.gadget, err = NewReservoirItemsSketch[T](u.maxK) + if err != nil { + return err + } } - u.gadget.Update(item) + return u.gadget.Update(item) } // UpdateSketch merges another sketch into the union. @@ -85,8 +89,7 @@ func (u *ReservoirItemsUnion[T]) UpdateSketch(sketch *ReservoirItemsSketch[T]) e return nil } - u.twoWayMergeInternal(ris) - return nil + return u.twoWayMergeInternal(ris) } // UpdateFromRaw creates a sketch from raw components and merges it. @@ -127,7 +130,9 @@ func (u *ReservoirItemsUnion[T]) createNewGadget(source *ReservoirItemsSketch[T] if err != nil { return err } - u.twoWayMergeInternalStandard(source) + if err := u.twoWayMergeInternalStandard(source); err != nil { + return err + } } else { u.gadget = source.Copy() @@ -137,40 +142,43 @@ func (u *ReservoirItemsUnion[T]) createNewGadget(source *ReservoirItemsSketch[T] // twoWayMergeInternal performs the merge based on the state of both sketches. // This implements Java's twoWayMergeInternal logic. -func (u *ReservoirItemsUnion[T]) twoWayMergeInternal(source *ReservoirItemsSketch[T]) { +func (u *ReservoirItemsUnion[T]) twoWayMergeInternal(source *ReservoirItemsSketch[T]) error { if source.N() <= int64(source.K()) { // Case 1: source is in exact mode - use standard merge - u.twoWayMergeInternalStandard(source) + return u.twoWayMergeInternalStandard(source) } else if u.gadget.N() < int64(u.gadget.K()) { // Case 2: gadget is in exact mode, source is in sampling mode // Swap: merge gadget into source (source becomes new gadget) tmp := u.gadget u.gadget = source.Copy() - u.twoWayMergeInternalStandard(tmp) + return u.twoWayMergeInternalStandard(tmp) } else if source.ImplicitSampleWeight() < float64(u.gadget.N())/float64(u.gadget.K()-1) { // Case 3: both in sampling mode, source is "lighter" // Merge source into gadget - u.twoWayMergeInternalWeighted(source) + return u.twoWayMergeInternalWeighted(source) } else { // Case 4: both in sampling mode, gadget is "lighter" // Swap: merge gadget into source tmp := u.gadget u.gadget = source.Copy() - u.twoWayMergeInternalWeighted(tmp) + return u.twoWayMergeInternalWeighted(tmp) } } // twoWayMergeInternalStandard merges a sketch in exact mode (N <= K) into gadget. // Simply updates gadget with each item from source. -func (u *ReservoirItemsUnion[T]) twoWayMergeInternalStandard(source *ReservoirItemsSketch[T]) { +func (u *ReservoirItemsUnion[T]) twoWayMergeInternalStandard(source *ReservoirItemsSketch[T]) error { for i := 0; i < source.NumSamples(); i++ { - u.gadget.Update(source.valueAtPosition(i)) + if err := u.gadget.Update(source.valueAtPosition(i)); err != nil { + return err + } } + return nil } // twoWayMergeInternalWeighted merges a "lighter" sketch into gadget using weighted sampling. // Uses the correct probability formula: P = (K * w) / targetTotal -func (u *ReservoirItemsUnion[T]) twoWayMergeInternalWeighted(source *ReservoirItemsSketch[T]) { +func (u *ReservoirItemsUnion[T]) twoWayMergeInternalWeighted(source *ReservoirItemsSketch[T]) error { numSourceSamples := source.K() sourceItemWeight := float64(source.N()) / float64(numSourceSamples) rescaledProb := float64(u.gadget.K()) * sourceItemWeight @@ -188,7 +196,7 @@ func (u *ReservoirItemsUnion[T]) twoWayMergeInternalWeighted(source *ReservoirIt } } - u.gadget.forceIncrementItemsSeen(source.N()) + return u.gadget.forceIncrementItemsSeen(source.N()) } // Result returns a copy of the internal sketch. @@ -315,7 +323,9 @@ func NewReservoirItemsUnionFromSlice[T any](data []byte, serde ItemsSerDe[T]) (* if err != nil { return nil, err } - union.UpdateSketch(sketch) + if err := union.UpdateSketch(sketch); err != nil { + return nil, err + } } return union, nil diff --git a/sampling/varopt_items_sketch.go b/sampling/varopt_items_sketch.go index 4461a9f..75b91f5 100644 --- a/sampling/varopt_items_sketch.go +++ b/sampling/varopt_items_sketch.go @@ -503,7 +503,7 @@ func (s *VarOptItemsSketch[T]) swap(i, j int) { // growDataArrays increases the capacity of data and weights arrays. func (s *VarOptItemsSketch[T]) growDataArrays() { - lgRf := mustResizeFactorLg(s.rf) + lgRf, _ := resizeFactorLg(s.rf) prevSize := s.allocatedSize newSize := s.adjustedSize(s.k, prevSize<