From b11a98e56a75b1b26342f1c36064e0b4f9e0ce15 Mon Sep 17 00:00:00 2001 From: lani_karrot Date: Sat, 14 Feb 2026 16:40:02 +0900 Subject: [PATCH] fix: VarOptItemsSketch update --- sampling/reservoir_items_sketch.go | 13 +- sampling/sampling.go | 31 ++ sampling/varopt_items_sketch.go | 403 ++++++++++++++++-------- sampling/varopt_items_sketch_test.go | 452 ++++++++++++++++----------- 4 files changed, 565 insertions(+), 334 deletions(-) create mode 100644 sampling/sampling.go diff --git a/sampling/reservoir_items_sketch.go b/sampling/reservoir_items_sketch.go index a88a14b..526f322 100644 --- a/sampling/reservoir_items_sketch.go +++ b/sampling/reservoir_items_sketch.go @@ -30,18 +30,7 @@ import ( "github.com/apache/datasketches-go/internal" ) -// ResizeFactor controls how the internal array grows. -// Note: Go's slice append has automatic resizing, so this is kept for -// API compatibility with the Java version. Can be removed if not needed. -// TODO: In Java, this is abstracted into a common package. Consider if this should be moved to a common package in the future. -type ResizeFactor int - const ( - ResizeX1 ResizeFactor = 1 - ResizeX2 ResizeFactor = 2 - ResizeX4 ResizeFactor = 4 - ResizeX8 ResizeFactor = 8 - defaultResizeFactor = ResizeX8 minK = 2 @@ -95,7 +84,7 @@ func NewReservoirItemsSketch[T any]( ceilingLgK, _ := internal.ExactLog2(common.CeilingPowerOf2(k)) initialLgSize := startingSubMultiple( - ceilingLgK, int(math.Log2(float64(options.resizeFactor))), minLgArrItems, + ceilingLgK, int(float64(options.resizeFactor)), minLgArrItems, ) return &ReservoirItemsSketch[T]{ k: k, diff --git a/sampling/sampling.go b/sampling/sampling.go new file mode 100644 index 0000000..d200475 --- /dev/null +++ b/sampling/sampling.go @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sampling + +// ResizeFactor controls how the internal array grows. +// Note: Go's slice append has automatic resizing, so this is kept for +// API compatibility with the Java version. Can be removed if not needed. +// TODO: In Java, this is abstracted into a common package. Consider if this should be moved to a common package in the future. +type ResizeFactor int + +const ( + ResizeX1 ResizeFactor = 0 + ResizeX2 ResizeFactor = 1 + ResizeX4 ResizeFactor = 2 + ResizeX8 ResizeFactor = 3 +) diff --git a/sampling/varopt_items_sketch.go b/sampling/varopt_items_sketch.go index 2b47028..c41664c 100644 --- a/sampling/varopt_items_sketch.go +++ b/sampling/varopt_items_sketch.go @@ -22,6 +22,9 @@ import ( "iter" "math" "math/rand" + "slices" + + "github.com/apache/datasketches-go/internal" ) // VarOptItemsSketch implements variance-optimal weighted sampling. @@ -39,27 +42,34 @@ import ( // Reference: Cohen et al., "Efficient Stream Sampling for Variance-Optimal // Estimation of Subset Sums", SIAM J. Comput. 40(5): 1402-1431, 2011. type VarOptItemsSketch[T any] struct { - k int // maximum sample size (user-configured) - n int64 // total number of items processed - h int // number of items in H (heavy/heap) region - m int // number of items in middle region (during candidate set operations) - r int // number of items in R (reservoir) region - totalWeightR float64 // total weight of items in R region - data []T // stored items - weights []float64 // corresponding weights for each item (-1.0 indicates R region) + data []T // stored items + weights []float64 // corresponding weights for each item (-1.0 indicates R region) + + // The following array is absent in a varopt sketch, and notionally present in a gadget + // (although it really belongs in the unioning object). If the array were to be made explicit, + // some additional coding would need to be done to ensure that all of the necessary data motion + // occurs and is properly tracked. + marks []bool + + k int // maximum sample size (user-configured) + n int64 // total number of items processed + h int // number of items in H (heavy/heap) region + m int // number of items in middle region (during candidate set operations) + r int // number of items in R (reservoir) region + totalWeightR float64 // total weight of items in R region // resize factor for array growth rf ResizeFactor - // current allocated capacity - allocatedSize int + // Following int is: + // 1. Zero (for a varopt sketch) + // 2. Count of marked items in H region, if part of a unioning algo's gadget + numMarksInH uint32 } const ( // VarOpt specific constants varOptDefaultResizeFactor = ResizeX8 - varOptMinLgK = 3 // minimum log2(k) = 3, so minimum k = 8 - varOptMinK = 1 << varOptMinLgK varOptMaxK = (1 << 31) - 2 // maximum k value ) @@ -75,38 +85,36 @@ func WithResizeFactor(rf ResizeFactor) VarOptOption { } } -func NewVarOptItemsSketch[T any](k int, opts ...VarOptOption) (*VarOptItemsSketch[T], error) { - if k < varOptMinK { - return nil, errors.New("k must be at least 8") - } - if k > varOptMaxK { - return nil, errors.New("k is too large") +func NewVarOptItemsSketch[T any](k uint, opts ...VarOptOption) (*VarOptItemsSketch[T], error) { + if k < 1 || k > varOptMaxK { + return nil, errors.New("k must be at least 1 and less than 2^31 - 1") } cfg := &varOptConfig{ resizeFactor: varOptDefaultResizeFactor, } - for _, opt := range opts { opt(cfg) } - initialSize := int(cfg.resizeFactor) - if initialSize > k { - initialSize = k + ceilingLgK := math.Log2(float64(internal.CeilPowerOf2(int(k)))) + initialLgSize := startingSubMultiple(int(ceilingLgK), int(cfg.resizeFactor), minLgArrItems) + currItemsAlloc := adjustedSamplingAllocationSize(int(k), 1< 0 { - tau := s.totalWeightR / float64(s.r) - rStart := s.h + s.m - for i := 0; i < s.r; i++ { - if !yield(Sample[T]{Item: s.data[rStart+i], Weight: tau}) { + tau := s.tau() + for i := s.h + 1; i <= s.k; i++ { + if !yield(Sample[T]{Item: s.data[i], Weight: tau}) { return } } @@ -180,62 +198,76 @@ func (s *VarOptItemsSketch[T]) inWarmup() bool { } // peekMin returns the minimum weight in the H region (heap root). -func (s *VarOptItemsSketch[T]) peekMin() float64 { +func (s *VarOptItemsSketch[T]) peekMin() (float64, error) { if s.h == 0 { - return math.Inf(1) + return 0, errors.New("h = 0 when checking min in H region") } - return s.weights[0] + return s.weights[0], nil } // Update adds an item with the given weight to the sketch. // Weight must be positive and finite. func (s *VarOptItemsSketch[T]) Update(item T, weight float64) error { - if weight < 0 || math.IsNaN(weight) || math.IsInf(weight, 0) { - return errors.New("weight must be nonnegative and finite") - } - if weight == 0 { - return nil // ignore zero weight items + return s.update(item, weight, false) +} + +func (s *VarOptItemsSketch[T]) update(item T, weight float64, mark bool) error { + if weight <= 0 || math.IsNaN(weight) || math.IsInf(weight, 0) { + return errors.New("weight must be strictly positive and finite") } s.n++ if s.r == 0 { // exact mode (warmup) - return s.updateWarmupPhase(item, weight) + return s.updateWarmupPhase(item, weight, mark) } - // estimation mode - // what tau would be if deletion candidates = R + new item - hypotheticalTau := (weight + s.totalWeightR) / float64(s.r) // r+1-1 = r + + minWeight, err := s.peekMin() + if err != nil { + return err + } + + if s.h != 0 && minWeight < s.tau() { + return errors.New("sketch not in valid estimation mode") + } + + // what tau would be if deletion candidates turn out to be R plus the new item + // NOTE: (r_ + 1) - 1 is intentional + hypotheticalTau := (weight + s.totalWeightR) / (float64(s.r+1) - 1) // is new item's turn to be considered for reservoir? - condition1 := s.h == 0 || weight <= s.peekMin() + condition1 := s.h == 0 || weight <= minWeight + // is new item light enough for reservoir? condition2 := weight < hypotheticalTau if condition1 && condition2 { - return s.updateLight(item, weight) + return s.updateLight(item, weight, mark) } else if s.r == 1 { - return s.updateHeavyREq1(item, weight) + return s.updateHeavyREqualsTo1(item, weight, mark) } - return s.updateHeavyGeneral(item, weight) + return s.updateHeavyGeneral(item, weight, mark) } // updateWarmupPhase handles the warmup phase when r=0. -func (s *VarOptItemsSketch[T]) updateWarmupPhase(item T, weight float64) error { +func (s *VarOptItemsSketch[T]) updateWarmupPhase(item T, weight float64, mark bool) error { if s.h >= cap(s.data) { - s.growDataArrays() + s.growArrays() } - // store items as they come in - if s.h < len(s.data) { - s.data[s.h] = item - s.weights[s.h] = weight - } else { - s.data = append(s.data, item) - s.weights = append(s.weights, weight) + // store items until full + s.data = append(s.data, item) + s.weights = append(s.weights, weight) + if s.marks != nil { + s.marks = append(s.marks, mark) } s.h++ + if mark { + s.numMarksInH++ + } + // check if need to transition to estimation mode if s.h > s.k { return s.transitionFromWarmup() @@ -245,18 +277,28 @@ func (s *VarOptItemsSketch[T]) updateWarmupPhase(item T, weight float64) error { // transitionFromWarmup converts from warmup (exact) mode to estimation mode. func (s *VarOptItemsSketch[T]) transitionFromWarmup() error { - // Convert to heap and move 2 lightest items to M region - s.heapify() - s.popMinToMRegion() - s.popMinToMRegion() + // Move the 2 lightest items from H to M + // But the lighter really belongs in R, so update counts to reflect that + if err := s.heapify(); err != nil { + return err + } + if err := s.popMinToMRegion(); err != nil { + return err + } + if err := s.popMinToMRegion(); err != nil { + return err + } // The lighter of the two really belongs in R s.m-- s.r++ - // h should be k-1, m should be 1, r should be 1 + if s.h != (s.k-1) || s.m != 1 || s.r != 1 { + return errors.New("invalid state for transitioning from warmup") + } - // Update total weight in R (the item at position k) + // Update total weight in R and then, having grabbed the value, overwrite + // in weight_ array to help make bugs more obvious s.totalWeightR = s.weights[s.k] s.weights[s.k] = -1.0 // mark as R region item @@ -265,52 +307,86 @@ func (s *VarOptItemsSketch[T]) transitionFromWarmup() error { return s.growCandidateSet(s.weights[s.k-1]+s.totalWeightR, 2) } -// updateLight handles a light item (weight <= old_tau) in estimation mode. -func (s *VarOptItemsSketch[T]) updateLight(item T, weight float64) error { +// NOTE: In the "light" case the new item has weight <= old_tau, so +// would appear to the right of the R items in a hypothetical reverse-sorted +// list. It is easy to prove that it is light enough to be part of this +// round's downsampling +func (s *VarOptItemsSketch[T]) updateLight(item T, weight float64, mark bool) error { + if s.r == 0 || (s.r+s.h) != s.k { + return errors.New("invalid sketch state during light warmup") + } + // The M slot is at index h (the gap) mSlot := s.h s.data[mSlot] = item s.weights[mSlot] = weight + if s.marks != nil { + s.marks[mSlot] = mark + } s.m++ return s.growCandidateSet(s.totalWeightR+weight, s.r+1) } -// updateHeavyGeneral handles a heavy item when r >= 2. -func (s *VarOptItemsSketch[T]) updateHeavyGeneral(item T, weight float64) error { +// NOTE: In the "heavy" case the new item has weight > old_tau, so would +// appear to the left of items in R in a hypothetical reverse-sorted list and +// might or might not be light enough be part of this round's downsampling. +// [After first splitting off the R=1 case] we greatly simplify the code by +// putting the new item into the H heap whether it needs to be there or not. +// In other words, it might go into the heap and then come right back out, +// but that should be okay because pseudo_heavy items cannot predominate +// in long streams unless (max wt) / (min wt) > o(exp(N)) +func (s *VarOptItemsSketch[T]) updateHeavyGeneral(item T, weight float64, mark bool) error { + if s.r < 2 || s.m != 0 || (s.r+s.h) != s.k { + return errors.New("invalid sketch state during heavy general update") + } + // Put into H (may come back out momentarily) - s.push(item, weight) + s.push(item, weight, mark) return s.growCandidateSet(s.totalWeightR, s.r) } -// updateHeavyREq1 handles a heavy item when r == 1. -func (s *VarOptItemsSketch[T]) updateHeavyREq1(item T, weight float64) error { - s.push(item, weight) // new item into H - s.popMinToMRegion() // pop lightest back into M +// NOTE: The analysis of this case is similar to that of the general heavy case. +// The one small technical difference is that since R < 2, we must grab an M item +// to have a valid starting point for growCandidateSet +func (s *VarOptItemsSketch[T]) updateHeavyREqualsTo1(item T, weight float64, mark bool) error { + if s.r != 1 || s.m != 0 || (s.r+s.h) != s.k { + return errors.New("invalid sketch state during heavy r=1 update") + } + + s.push(item, weight, mark) // new item into H + if err := s.popMinToMRegion(); err != nil { // pop lightest back into M + return err + } - // The M slot is at k-1 (array is k+1, 1 in R) - mSlot := s.k - 1 + // Any set of two items is downsample-able to one item, + // so the two lightest items are a valid starting point for the following + mSlot := s.k - 1 // The M slot is at k-1 (array is k+1, 1 in R) return s.growCandidateSet(s.weights[mSlot]+s.totalWeightR, 2) } -// push adds an item to the H region heap. -func (s *VarOptItemsSketch[T]) push(item T, weight float64) { - // Insert at position h (the gap) +func (s *VarOptItemsSketch[T]) push(item T, weight float64, mark bool) { s.data[s.h] = item s.weights[s.h] = weight + if s.marks != nil { + s.marks[s.h] = mark + if mark { + s.numMarksInH++ + } + } s.h++ - s.siftUp(s.h - 1) + s.restoreTowardsRoot(s.h - 1) } // popMinToMRegion moves the minimum item from H to M region. -func (s *VarOptItemsSketch[T]) popMinToMRegion() { - if s.h == 0 { - return +func (s *VarOptItemsSketch[T]) popMinToMRegion() error { + if s.h == 0 || (s.h+s.m+s.r) != (s.k+1) { + return errors.New("invalid heap state popping min to M region") } - if s.h == 1 { + if s.h == 1 { // just update bookkeeping s.m++ s.h-- } else { @@ -318,21 +394,49 @@ func (s *VarOptItemsSketch[T]) popMinToMRegion() { s.swap(0, tgt) s.m++ s.h-- - s.siftDown(0) + + if err := s.restoreTowardsLeaves(0); err != nil { + return err + } + } + + if s.isMarked() { + s.numMarksInH-- } + return nil } -// growCandidateSet grows the candidate set by pulling light items from H to M. +func (s *VarOptItemsSketch[T]) isMarked() bool { + return s.marks != nil && s.marks[s.h] +} + +// NOTE: When entering here we should be in a well-characterized state where the +// new item has been placed in either h or m and we have a valid but not necessarily +// maximal sampling plan figured out. The array is completely full at this point. +// Everyone in h and m has an explicit weight. The candidates are right-justified +// and are either just the r set or the r set + exactly one m item. The number +// of cands is at least 2. We will now grow the candidate set as much as possible +// by pulling sufficiently light items from h to m. func (s *VarOptItemsSketch[T]) growCandidateSet(wtCands float64, numCands int) error { + if (s.h+s.m+s.r != s.k+1) || numCands < 1 || numCands != (s.m+s.r) || s.m >= 2 { + return errors.New("invariant violated when growing candidate set") + } + for s.h > 0 { - nextWt := s.peekMin() + nextWt, err := s.peekMin() + if err != nil { + return err + } + nextTotWt := wtCands + nextWt // test for strict lightness: nextWt * numCands < nextTotWt if nextWt*float64(numCands) < nextTotWt { wtCands = nextTotWt numCands++ - s.popMinToMRegion() + if err := s.popMinToMRegion(); err != nil { + return err + } } else { break } @@ -343,8 +447,8 @@ func (s *VarOptItemsSketch[T]) growCandidateSet(wtCands float64, numCands int) e // downsampleCandidateSet downsamples the candidate set to produce final R. func (s *VarOptItemsSketch[T]) downsampleCandidateSet(wtCands float64, numCands int) error { - if numCands < 2 { - return nil + if numCands < 2 || s.h+numCands != s.k+1 { + return errors.New("invalid numCands when downsampling") } // Choose which slot to delete @@ -354,18 +458,20 @@ func (s *VarOptItemsSketch[T]) downsampleCandidateSet(wtCands float64, numCands } leftmostCandSlot := s.h + if deleteSlot < leftmostCandSlot || deleteSlot > s.k { + return errors.New("invalid delete slot index when downsampling") + } - // Mark weights for items moving from M to R as -1 + // Overwrite weights for items from M moving into R, + // to make bugs more obvious. Also needed so anyone reading the + // weight knows if it's invalid without checking h and m stopIdx := leftmostCandSlot + s.m for j := leftmostCandSlot; j < stopIdx; j++ { s.weights[j] = -1.0 } - // Move the delete slot's content to leftmost candidate position - // This works even when deleteSlot == leftmostCandSlot - if deleteSlot != leftmostCandSlot { - s.data[deleteSlot] = s.data[leftmostCandSlot] - } + // The next line works even when delete_slot == leftmost_cand_slot + s.data[deleteSlot] = s.data[leftmostCandSlot] s.m = 0 s.r = numCands - 1 @@ -376,26 +482,27 @@ func (s *VarOptItemsSketch[T]) downsampleCandidateSet(wtCands float64, numCands // chooseDeleteSlot randomly selects which item to delete from candidates. func (s *VarOptItemsSketch[T]) chooseDeleteSlot(wtCands float64, numCands int) (int, error) { if s.r == 0 { - return 0, errors.New("chooseDeleteSlot called while in exact mode (r == 0)") + return 0, errors.New("choosing delete slot while in exact mode") } - if s.m == 0 { - // All candidates are in R, pick random slot - return s.randomRIndex(), nil - } else if s.m == 1 { + switch s.m { + case 0: + // this happens if we insert a really heavy item + return s.pickRandomSlotInR() + case 1: // Check if we keep the item in M or pick one from R // p(keep) = (numCands - 1) * wtM / wtCands wtMCand := s.weights[s.h] // slot of item in M is h if wtCands*s.randFloat64NonZero() < float64(numCands-1)*wtMCand { - return s.randomRIndex(), nil // keep item in M + return s.pickRandomSlotInR() // keep item in M } return s.h, nil // delete item in M - } else { + default: // General case with multiple M items deleteSlot := s.chooseWeightedDeleteSlot(wtCands, numCands) firstRSlot := s.h + s.m if deleteSlot == firstRSlot { - return s.randomRIndex(), nil + return s.pickRandomSlotInR() } return deleteSlot, nil } @@ -423,13 +530,17 @@ func (s *VarOptItemsSketch[T]) chooseWeightedDeleteSlot(wtCands float64, numCand return finalM + 1 } -// randomRIndex returns a random index from the R region. -func (s *VarOptItemsSketch[T]) randomRIndex() int { +// pickRandomSlotInR returns a random index from the R region. +func (s *VarOptItemsSketch[T]) pickRandomSlotInR() (int, error) { + if s.r == 0 { + return 0, errors.New("r == 0 when picking slot in R region") + } + offset := s.h + s.m if s.r == 1 { - return offset + return offset, nil } - return offset + rand.Intn(s.r) + return offset + rand.Intn(s.r), nil } // randFloat64NonZero returns a random float64 in (0, 1). @@ -442,23 +553,30 @@ func (s *VarOptItemsSketch[T]) randFloat64NonZero() float64 { } } -// heapify converts H region to a valid min-heap. -func (s *VarOptItemsSketch[T]) heapify() { +// heapify converts data and weights to heap. +func (s *VarOptItemsSketch[T]) heapify() error { if s.h < 2 { - return + return nil } lastSlot := s.h - 1 lastNonLeaf := ((lastSlot + 1) / 2) - 1 for j := lastNonLeaf; j >= 0; j-- { - s.siftDown(j) + if err := s.restoreTowardsLeaves(j); err != nil { + return err + } } + return nil } // siftDown restores heap property by moving element down. -func (s *VarOptItemsSketch[T]) siftDown(slotIn int) { +func (s *VarOptItemsSketch[T]) restoreTowardsLeaves(slotIn int) error { lastSlot := s.h - 1 + if s.h == 0 || slotIn > lastSlot { + return errors.New("invalid heap state") + } + slot := slotIn child := 2*slotIn + 1 @@ -476,10 +594,18 @@ func (s *VarOptItemsSketch[T]) siftDown(slotIn int) { slot = child child = 2*slot + 1 } + + return nil +} + +func (s *VarOptItemsSketch[T]) tau() float64 { + if s.r == 0 { + return math.NaN() + } + return s.totalWeightR / float64(s.r) } -// siftUp restores heap property by moving element up. -func (s *VarOptItemsSketch[T]) siftUp(slotIn int) { +func (s *VarOptItemsSketch[T]) restoreTowardsRoot(slotIn int) { slot := slotIn p := ((slot + 1) / 2) - 1 // parent @@ -490,37 +616,36 @@ func (s *VarOptItemsSketch[T]) siftUp(slotIn int) { } } -// swap exchanges items at two positions. func (s *VarOptItemsSketch[T]) swap(i, j int) { s.data[i], s.data[j] = s.data[j], s.data[i] s.weights[i], s.weights[j] = s.weights[j], s.weights[i] + + if s.marks != nil { + s.marks[i], s.marks[j] = s.marks[j], s.marks[i] + } } -// growDataArrays increases the capacity of data and weights arrays. -func (s *VarOptItemsSketch[T]) growDataArrays() { - prevSize := s.allocatedSize +func (s *VarOptItemsSketch[T]) growArrays() { + prevSize := cap(s.data) newSize := s.adjustedSize(s.k, prevSize< prevSize { - newData := make([]T, len(s.data), newSize) - copy(newData, s.data) - s.data = newData + if prevSize < newSize { + s.data = slices.Grow(s.data, newSize) + s.weights = slices.Grow(s.weights, newSize) - newWeights := make([]float64, len(s.weights), newSize) - copy(newWeights, s.weights) - s.weights = newWeights - - s.allocatedSize = newSize + if s.marks != nil { + s.marks = slices.Grow(s.marks, newSize) + } } } // adjustedSize returns the appropriate array size. func (s *VarOptItemsSketch[T]) adjustedSize(maxSize, resizeTarget int) int { - if resizeTarget <= maxSize { - return resizeTarget + if maxSize < (resizeTarget << 1) { + return maxSize } - return maxSize + return resizeTarget } diff --git a/sampling/varopt_items_sketch_test.go b/sampling/varopt_items_sketch_test.go index cd21127..0f9ca04 100644 --- a/sampling/varopt_items_sketch_test.go +++ b/sampling/varopt_items_sketch_test.go @@ -19,205 +19,291 @@ package sampling import ( "math" - "math/rand" "testing" + + "github.com/stretchr/testify/assert" ) -func TestVarOptItemsSketch_NewSketch(t *testing.T) { - // Test valid k - sketch, err := NewVarOptItemsSketch[string](16) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if sketch.K() != 16 { - t.Errorf("expected K=16, got %d", sketch.K()) - } - if sketch.N() != 0 { - t.Errorf("expected N=0, got %d", sketch.N()) - } - if !sketch.IsEmpty() { - t.Error("expected empty sketch") - } - - // Test k too small - _, err = NewVarOptItemsSketch[string](4) - if err == nil { - t.Error("expected error for k < 8") - } - - // Test k too large - _, err = NewVarOptItemsSketch[string](varOptMaxK + 1) - if err == nil { - t.Error("expected error for k > varOptMaxK") - } -} +func TestNewVarOptItemsSketch(t *testing.T) { + t.Run("valid K", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[string](16) + assert.NoError(t, err) + assert.Equal(t, 16, sketch.K()) + assert.Equal(t, int64(0), sketch.N()) + assert.True(t, sketch.IsEmpty()) + }) -func TestVarOptItemsSketch_WarmupPhase(t *testing.T) { - sketch, _ := NewVarOptItemsSketch[int](10) - - // Add fewer than k items - should all be stored - for i := 1; i <= 5; i++ { - err := sketch.Update(i, float64(i)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - } - - if sketch.N() != 5 { - t.Errorf("expected N=5, got %d", sketch.N()) - } - if sketch.NumSamples() != 5 { - t.Errorf("expected NumSamples=5, got %d", sketch.NumSamples()) - } - if !sketch.inWarmup() { - t.Error("expected to still be in warmup mode") - } + t.Run("K is too large", func(t *testing.T) { + _, err := NewVarOptItemsSketch[string](varOptMaxK + 1) + assert.ErrorContains(t, err, "k must be at least 1 and less than 2^31 - 1") + }) } -func TestVarOptItemsSketch_TransitionToEstimation(t *testing.T) { - sketch, _ := NewVarOptItemsSketch[int](8) - - // Need k+1 items to trigger transition (h > k condition) - for i := 1; i <= 9; i++ { - err := sketch.Update(i, float64(i)) - if err != nil { - t.Fatalf("unexpected error at i=%d: %v", i, err) - } - } - - if sketch.N() != 9 { - t.Errorf("expected N=9, got %d", sketch.N()) - } - // After transition, H + R should equal k - if sketch.NumSamples() != 8 { - t.Errorf("expected NumSamples=8, got %d", sketch.NumSamples()) - } - // Should have transitioned out of warmup - if sketch.inWarmup() { - t.Error("expected to NOT be in warmup mode after filling") - } - // Should have some items in R region - if sketch.R() == 0 { - t.Error("expected R > 0 after transition") - } -} +func TestVarOptItemsSketch_NumSamples(t *testing.T) { + t.Run("empty sketch returns 0", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + assert.Equal(t, 0, sketch.NumSamples()) + }) -func TestVarOptItemsSketch_EstimationMode(t *testing.T) { - sketch, _ := NewVarOptItemsSketch[int](8) - - // Fill and then add more - for i := 1; i <= 20; i++ { - err := sketch.Update(i, float64(i)) - if err != nil { - t.Fatalf("unexpected error at i=%d: %v", i, err) - } - } - - if sketch.N() != 20 { - t.Errorf("expected N=20, got %d", sketch.N()) - } - // Should still have at most k samples - if sketch.NumSamples() > sketch.K() { - t.Errorf("expected NumSamples <= K, got %d > %d", sketch.NumSamples(), sketch.K()) - } - // H + R should be <= k - if sketch.H()+sketch.R() > sketch.K() { - t.Errorf("expected H+R <= K, got %d+%d > %d", sketch.H(), sketch.R(), sketch.K()) - } -} + t.Run("fewer items than k returns number of items", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + for i := 1; i <= 5; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + assert.Equal(t, 5, sketch.NumSamples()) + }) + + t.Run("exactly k items returns k", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + for i := 1; i <= 10; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + assert.Equal(t, 10, sketch.NumSamples()) + }) -func TestVarOptItemsSketch_InvalidWeight(t *testing.T) { - sketch, _ := NewVarOptItemsSketch[string](8) - - // Negative weight should error - err := sketch.Update("a", -1.0) - if err == nil { - t.Error("expected error for negative weight") - } - - // Zero weight is valid in C++/Java - just ignored - err = sketch.Update("b", 0.0) - if err != nil { - t.Errorf("zero weight should be valid (ignored), got error: %v", err) - } - // Sketch should still be empty since zero weight is ignored - if sketch.N() != 0 { - t.Errorf("expected N=0 after zero weight update, got %d", sketch.N()) - } + t.Run("more than k items returns k", func(t *testing.T) { + k := 100 + sketch, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + for i := 1; i <= 200; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + assert.Equal(t, k, sketch.NumSamples()) + }) } func TestVarOptItemsSketch_Reset(t *testing.T) { - sketch, _ := NewVarOptItemsSketch[int](8) - - for i := 1; i <= 10; i++ { - sketch.Update(i, float64(i)) - } - - sketch.Reset() - - if !sketch.IsEmpty() { - t.Error("expected empty after reset") - } - if sketch.N() != 0 { - t.Errorf("expected N=0 after reset, got %d", sketch.N()) - } - if sketch.H() != 0 || sketch.R() != 0 { - t.Errorf("expected H=0, R=0 after reset, got H=%d, R=%d", sketch.H(), sketch.R()) - } -} + t.Run("exact mode", func(t *testing.T) { + k := 10 + sketch, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + for i := 1; i <= 5; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + assert.Equal(t, int64(5), sketch.N()) + assert.Equal(t, k, sketch.K()) + assert.False(t, sketch.IsEmpty()) -func TestVarOptItemsSketch_UniformWeights(t *testing.T) { - // With uniform weights, VarOpt should behave like reservoir sampling - sketch, _ := NewVarOptItemsSketch[int](10) + sketch.Reset() - for i := 1; i <= 100; i++ { - sketch.Update(i, 1.0) - } + assert.Equal(t, 10, sketch.K()) + assert.Equal(t, int64(0), sketch.N()) + assert.Equal(t, 0, sketch.NumSamples()) + assert.True(t, sketch.IsEmpty()) + assert.Equal(t, 0, sketch.H()) + assert.Equal(t, 0, sketch.R()) + }) + + t.Run("estimation mode", func(t *testing.T) { + k := 100 + sketch, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + for i := 1; i <= 200; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + assert.Equal(t, int64(200), sketch.N()) + assert.Equal(t, k, sketch.K()) + assert.Equal(t, k, sketch.NumSamples()) - if sketch.N() != 100 { - t.Errorf("expected N=100, got %d", sketch.N()) - } - if sketch.NumSamples() > sketch.K() { - t.Errorf("expected NumSamples <= K, got %d", sketch.NumSamples()) - } + sketch.Reset() + + assert.Equal(t, k, sketch.K()) + assert.Equal(t, int64(0), sketch.N()) + assert.Equal(t, 0, sketch.NumSamples()) + assert.True(t, sketch.IsEmpty()) + assert.Equal(t, 0, sketch.H()) + assert.Equal(t, 0, sketch.R()) + }) } -func TestVarOptItemsSketch_CumulativeWeight(t *testing.T) { - // This test verifies that the sum of output weights equals the sum of input weights. - // This is a key property of VarOpt sketches. - // Matches C++ test: "varopt sketch: cumulative weight" - const eps = 1e-13 - k := 256 - n := 10 * k - - sketch, _ := NewVarOptItemsSketch[int](k) - - inputSum := 0.0 - for i := 0; i < n; i++ { - // Generate weights using exp(5*N(0,1)) to cover ~10 orders of magnitude - // This matches the C++ test distribution - w := math.Exp(5 * randNormal()) - inputSum += w - sketch.Update(i, w) - } - - // Get output weights using Go 1.23 iterator - outputSum := 0.0 - for sample := range sketch.All() { - outputSum += sample.Weight - } - - // The ratio should be exactly 1.0 (within floating point precision) - ratio := outputSum / inputSum - if math.Abs(ratio-1.0) > eps { - t.Errorf("weight ratio out of expected range: got %f, expected 1.0 (±%e)", ratio, eps) - } +func TestVarOptItemsSketch_All(t *testing.T) { + t.Run("empty sketch", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + count := 0 + for range sketch.All() { + count++ + } + assert.Equal(t, 0, count) + }) + + t.Run("exact mode", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + expectedWeights := map[int]float64{} + for i := 1; i <= 5; i++ { + w := float64(i) * 10.0 + err = sketch.Update(i, w) + assert.NoError(t, err) + expectedWeights[i] = w + } + + count := 0 + for sample := range sketch.All() { + w, ok := expectedWeights[sample.Item] + assert.True(t, ok, "unexpected item %d", sample.Item) + assert.Equal(t, w, sample.Weight) + count++ + } + assert.Equal(t, 5, count) + }) + + t.Run("estimation mode", func(t *testing.T) { + k := 100 + sketch, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + for i := 1; i <= 200; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + + hCount := sketch.H() + rCount := sketch.R() + assert.Equal(t, k, hCount+rCount) + + tau := sketch.totalWeightR / float64(rCount) + + idx := 0 + for sample := range sketch.All() { + assert.True(t, sample.Weight > 0, "weight should be positive") + if idx >= hCount { + // R region items should have weight == tau + assert.InDelta(t, tau, sample.Weight, 1e-10) + } + idx++ + } + assert.Equal(t, k, idx) + }) + + t.Run("early break stops iteration", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + for i := 1; i <= 5; i++ { + err = sketch.Update(i, float64(i)) + assert.NoError(t, err) + } + + count := 0 + for range sketch.All() { + count++ + if count == 3 { + break + } + } + assert.Equal(t, 3, count) + }) } -// randNormal returns a random number from standard normal distribution N(0,1) -func randNormal() float64 { - // Box-Muller transform - u1 := rand.Float64() - u2 := rand.Float64() - return math.Sqrt(-2*math.Log(u1)) * math.Cos(2*math.Pi*u2) +func TestVarOptItemsSketch_Update(t *testing.T) { + t.Run("negative weight", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + err = sketch.Update(1, -1.0) + assert.ErrorContains(t, err, "weight must be strictly positive and finite") + assert.Equal(t, int64(0), sketch.N()) + }) + + t.Run("zero weight", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + err = sketch.Update(1, 0.0) + assert.ErrorContains(t, err, "weight must be strictly positive and finite") + assert.Equal(t, int64(0), sketch.N()) + }) + + t.Run("NaN weight", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + err = sketch.Update(1, math.NaN()) + assert.ErrorContains(t, err, "weight must be strictly positive and finite") + assert.Equal(t, int64(0), sketch.N()) + }) + + t.Run("positive infinity weight", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + err = sketch.Update(1, math.Inf(1)) + assert.ErrorContains(t, err, "weight must be strictly positive and finite") + assert.Equal(t, int64(0), sketch.N()) + }) + + t.Run("negative infinity weight", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + err = sketch.Update(1, math.Inf(-1)) + assert.ErrorContains(t, err, "weight must be strictly positive and finite") + assert.Equal(t, int64(0), sketch.N()) + }) + + t.Run("exact mode", func(t *testing.T) { + sketch, err := NewVarOptItemsSketch[int](10) + assert.NoError(t, err) + + inputWeightSum := float64(0) + for i := 1; i <= 5; i++ { + w := float64(i) + + err = sketch.Update(i, w) + assert.NoError(t, err) + + inputWeightSum += w + } + + outputWeightSum := float64(0) + for sample := range sketch.All() { + outputWeightSum += sample.Weight + } + + assert.Equal(t, 5, sketch.H()) + assert.Equal(t, 0, sketch.R()) + assert.False(t, sketch.IsEmpty()) + + // check cumulative weight + weightRatio := outputWeightSum / inputWeightSum + assert.InDelta(t, weightRatio, 1.0, 1e-10) + }) + + t.Run("estimation mode", func(t *testing.T) { + k := 100 + sketch, err := NewVarOptItemsSketch[int](uint(k)) + assert.NoError(t, err) + + inputWeightSum := float64(0) + for i := 1; i <= 200; i++ { + w := float64(i) + + err = sketch.Update(i, w) + assert.NoError(t, err) + + inputWeightSum += w + } + + outputWeightSum := float64(0) + for sample := range sketch.All() { + outputWeightSum += sample.Weight + } + + assert.Equal(t, int64(200), sketch.N()) + assert.Equal(t, k, sketch.H()+sketch.R()) + assert.True(t, sketch.totalWeightR > 0) + + // check cumulative weight + weightRatio := outputWeightSum / inputWeightSum + assert.InDelta(t, weightRatio, 1.0, 1e-10) + }) }