From 3e027ce2e8d7e7aafa921dea064581420f6c4e5d Mon Sep 17 00:00:00 2001 From: lani_karrot Date: Wed, 18 Feb 2026 16:00:56 +0900 Subject: [PATCH] refactor: remove panic in CPC sketch --- cpc/cpc_compressed_state.go | 110 ++++++++++++++++++++----------- cpc/cpc_compressed_state_test.go | 15 +++-- 2 files changed, 80 insertions(+), 45 deletions(-) diff --git a/cpc/cpc_compressed_state.go b/cpc/cpc_compressed_state.go index 2480988..b77a646 100644 --- a/cpc/cpc_compressed_state.go +++ b/cpc/cpc_compressed_state.go @@ -18,6 +18,7 @@ package cpc import ( + "errors" "fmt" "math/bits" @@ -271,7 +272,10 @@ func (c *CpcCompressedState) compressHybridFlavor(src *CpcSketch) error { return fmt.Errorf("compressHybridFlavor: invariant violation (%d + %d != %d)", numPairsFromArray, srcNumPairs, srcNumCoupons) } - allPairs := trickyGetPairsFromWindow(srcSlidingWindow, srcK, numPairsFromArray, srcNumPairs) + allPairs, err := trickyGetPairsFromWindow(srcSlidingWindow, srcK, numPairsFromArray, srcNumPairs) + if err != nil { + return err + } mergePairs(srcPairArr, 0, srcNumPairs, allPairs, srcNumPairs, numPairsFromArray, allPairs, 0) return compressTheSurprisingValues(c, src, allPairs, int(srcNumCoupons)) } @@ -428,7 +432,10 @@ func (c *CpcCompressedState) compressSlidingFlavor(src *CpcSketch) error { } // Apply a transformation to the column indices. - pseudoPhase := determinePseudoPhase(src.lgK, int64(src.numCoupons)) + pseudoPhase, err := determinePseudoPhase(src.lgK, int64(src.numCoupons)) + if err != nil { + return err + } if pseudoPhase >= 16 { return fmt.Errorf("compressSlidingFlavor: pseudoPhase (%d) >= 16", pseudoPhase) } @@ -492,7 +499,10 @@ func (c *CpcCompressedState) uncompressSlidingFlavor(src *CpcSketch) error { } // Determine pseudoPhase. - pseudoPhase := determinePseudoPhase(srcLgK, int64(c.NumCoupons)) + pseudoPhase, err := determinePseudoPhase(srcLgK, int64(c.NumCoupons)) + if err != nil { + return err + } if pseudoPhase >= 16 { return fmt.Errorf("uncompressSlidingFlavor: pseudoPhase %d out of range", pseudoPhase) } @@ -608,7 +618,7 @@ func importFromMemory(bytes []byte) (*CpcCompressedState, error) { state.CwStream = getWStream(bytes) state.CsvStream = getSvStream(bytes) default: - panic("not implemented") + return nil, fmt.Errorf("unknown format: %d", format) } return state, nil } @@ -675,14 +685,23 @@ func compressTheSurprisingValues(target *CpcCompressedState, source *CpcSketch, // Compute srcK = 1 << source.lgK. srcK := 1 << source.lgK // Determine the number of base bits using a Golomb code decision. - numBaseBits := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs) + numBaseBits, err := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs) + if err != nil { + return err + } // Compute an upper-bound length for the compressed pairs buffer. - pairBufLen := safeLengthForCompressedPairBuf(srcK, numPairs, numBaseBits) + pairBufLen, err := safeLengthForCompressedPairBuf(srcK, numPairs, numBaseBits) + if err != nil { + return err + } // Allocate the buffer for compression. pairBuf := make([]int, pairBufLen) // lowLevelCompressPairs compresses 'pairs' using the chosen base bits into pairBuf. // It returns the number of ints that represent the compressed data. - csvLength := lowLevelCompressPairs(pairs, numPairs, numBaseBits, pairBuf) + csvLength, err := lowLevelCompressPairs(pairs, numPairs, numBaseBits, pairBuf) + if err != nil { + return err + } target.CsvLengthInts = csvLength target.CsvStream = pairBuf return nil @@ -696,7 +715,10 @@ func uncompressTheSurprisingValues(source *CpcCompressedState) ([]int, error) { } pairs := make([]int, numPairs) // Determine the number of base bits using the Golomb code decision. - numBaseBits := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs) + numBaseBits, err := golombChooseNumberOfBaseBits(srcK+numPairs, numPairs) + if err != nil { + return nil, err + } // lowLevelUncompressPairs fills the 'pairs' slice using the compressed CSV stream. if err := lowLevelUncompressPairs(pairs, numPairs, numBaseBits, source.CsvStream, source.CsvLengthInts); err != nil { return nil, err @@ -704,24 +726,24 @@ func uncompressTheSurprisingValues(source *CpcCompressedState) ([]int, error) { return pairs, nil } -func golombChooseNumberOfBaseBits(k, count int) int { +func golombChooseNumberOfBaseBits(k, count int) (int, error) { if k < 1 || count < 1 { - panic("golombChooseNumberOfBaseBits: k and count must be >= 1") + return 0, errors.New("golombChooseNumberOfBaseBits: k and count must be >= 1") } quotient := (k - count) / count if quotient == 0 { - return 0 + return 0, nil } - return floorLog2(uint64(quotient)) + return floorLog2(uint64(quotient)), nil } func floorLog2(x uint64) int { return bits.Len64(x) - 1 } -func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) int { +func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) (int, error) { if numPairs <= 0 { - panic("safeLengthForCompressedPairBuf: numPairs must be > 0") + return 0, errors.New("safeLengthForCompressedPairBuf: numPairs must be > 0") } // Compute ybits = (numPairs * (1 + numBaseBits)) + (k >>> numBaseBits) ybits := int64(numPairs)*(1+int64(numBaseBits)) + (int64(k) >> uint(numBaseBits)) @@ -736,9 +758,9 @@ func safeLengthForCompressedPairBuf(k, numPairs, numBaseBits int) int { words := divideBy32RoundingUp(totalBits) // Ensure the number of words fits in a 31-bit int. if words >= (1 << 31) { - panic("safeLengthForCompressedPairBuf: words too large") + return 0, errors.New("safeLengthForCompressedPairBuf: words too large") } - return int(words) + return int(words), nil } func divideBy32RoundingUp(x int64) int64 { @@ -749,7 +771,7 @@ func divideBy32RoundingUp(x int64) int64 { return tmp + 1 } -func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, compressedWords []int) int { +func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, compressedWords []int) (int, error) { nextWordIndex := 0 var bitBuf uint64 = 0 bufBits := 0 @@ -773,8 +795,7 @@ func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, c predictedColIndex = 0 } if rowIndex < predictedRowIndex || colIndex < predictedColIndex { - panic(fmt.Sprintf("lowLevelCompressPairs: assertion failed: rowIndex=%d, predictedRowIndex=%d, colIndex=%d, predictedColIndex=%d", - rowIndex, predictedRowIndex, colIndex, predictedColIndex)) + return 0, fmt.Errorf("lowLevelCompressPairs: assertion failed: rowIndex=%d, predictedRowIndex=%d, colIndex=%d, predictedColIndex=%d", rowIndex, predictedRowIndex, colIndex, predictedColIndex) } // yDelta is the difference in row indices. @@ -846,7 +867,7 @@ func lowLevelCompressPairs(pairArray []int, numPairsToEncode, numBaseBits int, c compressedWords[nextWordIndex] = int(bitBuf & 0xFFFFFFFF) nextWordIndex++ } - return nextWordIndex + return nextWordIndex, nil } func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int, compressedWords []int, numCompressedWords int) error { @@ -889,7 +910,10 @@ func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int, ptrArr[NextWordIdx] = int64(nextWordIndex) ptrArr[BitBuf] = int64(bitBuf) ptrArr[BufBits] = int64(bufBits) - golombHi := readUnary(compressedWords, ptrArr) + golombHi, err := readUnary(compressedWords, ptrArr) + if err != nil { + return err + } // Retrieve updated values. nextWordIndex = int(ptrArr[NextWordIdx]) bitBuf = uint64(ptrArr[BitBuf]) @@ -931,7 +955,7 @@ func lowLevelUncompressPairs(pairArray []int, numPairsToDecode, numBaseBits int, return nil } -func readUnary(compressedWords []int, ptrArr []int64) int64 { +func readUnary(compressedWords []int, ptrArr []int64) (int64, error) { nextWordIndex := int(ptrArr[NextWordIdx]) bitBuf := uint64(ptrArr[BitBuf]) bufBits := int(ptrArr[BufBits]) @@ -944,7 +968,7 @@ func readUnary(compressedWords []int, ptrArr []int64) int64 { // Ensure we have at least 8 bits in the bit buffer. if bufBits < 8 { if nextWordIndex >= len(compressedWords) { - panic("readUnary: insufficient compressedWords data") + return 0, errors.New("readUnary: insufficient compressedWords data") } bitBuf |= (uint64(compressedWords[nextWordIndex]) & 0xFFFFFFFF) << uint(bufBits) nextWordIndex++ @@ -975,7 +999,7 @@ func readUnary(compressedWords []int, ptrArr []int64) int64 { ptrArr[BitBuf] = int64(bitBuf) ptrArr[BufBits] = int64(bufBits) - return subTotal + int64(trailingZeros) + return subTotal + int64(trailingZeros), nil } func writeUnary(compressedWords []int, ptrArr []int64, theValue int) { @@ -1011,7 +1035,7 @@ func writeUnary(compressedWords []int, ptrArr []int64, theValue int) { ptrArr[BufBits] = int64(bufBits) } -func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) []int { +func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) ([]int, error) { outputLength := emptySpace + numPairsToGet pairs := make([]int, outputLength) pairIndex := emptySpace @@ -1031,10 +1055,10 @@ func trickyGetPairsFromWindow(window []byte, k, numPairsToGet, emptySpace int) [ } if pairIndex != outputLength { - panic(fmt.Sprintf("trickyGetPairsFromWindow: pairIndex (%d) != outputLength (%d)", pairIndex, outputLength)) + return nil, fmt.Errorf("trickyGetPairsFromWindow: pairIndex (%d) != outputLength (%d)", pairIndex, outputLength) } - return pairs + return pairs, nil } func (c *CpcCompressedState) compressTheWindow(src *CpcSketch) error { @@ -1045,7 +1069,10 @@ func (c *CpcCompressedState) compressTheWindow(src *CpcSketch) error { windowBufLen := safeLengthForCompressedWindowBuf(int64(srcK)) windowBuf := make([]int, windowBufLen) // Determine the pseudo-phase using srcLgK and the number of coupons. - pseudoPhase := determinePseudoPhase(srcLgK, int64(src.numCoupons)) + pseudoPhase, err := determinePseudoPhase(srcLgK, int64(src.numCoupons)) + if err != nil { + return err + } // Compress the sliding window bytes. // lowLevelCompressBytes is assumed to return (cwLengthInts int, err error). cwLengthInts := lowLevelCompressBytes(src.slidingWindow, srcK, encodingTablesForHighEntropyByte[pseudoPhase], windowBuf) @@ -1069,7 +1096,10 @@ func uncompressTheWindow(target *CpcSketch, source *CpcCompressedState) error { target.slidingWindow = window // Determine the pseudo-phase using srcLgK and source.NumCoupons. - pseudoPhase := determinePseudoPhase(srcLgK, int64(source.NumCoupons)) + pseudoPhase, err := determinePseudoPhase(srcLgK, int64(source.NumCoupons)) + if err != nil { + return err + } // Ensure that source.CwStream is not nil. if source.CwStream == nil { return fmt.Errorf("uncompressTheWindow: source.CwStream is nil") @@ -1091,37 +1121,37 @@ func safeLengthForCompressedWindowBuf(k int64) int { return int(divideBy32RoundingUp(totalBits)) } -func determinePseudoPhase(lgK int, numCoupons int64) int { +func determinePseudoPhase(lgK int, numCoupons int64) (int, error) { k := int64(1) << uint(lgK) c := numCoupons // Midrange logic. if (1000 * c) < (2375 * k) { if (4 * c) < (3 * k) { - return 16 + 0 + return 16 + 0, nil } else if (10 * c) < (11 * k) { - return 16 + 1 + return 16 + 1, nil } else if (100 * c) < (132 * k) { - return 16 + 2 + return 16 + 2, nil } else if (3 * c) < (5 * k) { - return 16 + 3 + return 16 + 3, nil } else if (1000 * c) < (1965 * k) { - return 16 + 4 + return 16 + 4, nil } else if (1000 * c) < (2275 * k) { - return 16 + 5 + return 16 + 5, nil } else { - return 6 // steady-state table employed before its actual phase. + return 6, nil // steady-state table employed before its actual phase. } } else { // Steady-state logic. if lgK < 4 { - panic("determinePseudoPhase: lgK must be at least 4") + return 0, errors.New("determinePseudoPhase: lgK must be at least 4") } tmp := c >> uint(lgK-4) phase := int(tmp & 15) if phase < 0 || phase >= 16 { - panic(fmt.Sprintf("determinePseudoPhase: phase out of range: %d", phase)) + return 0, fmt.Errorf("determinePseudoPhase: phase out of range: %d", phase) } - return phase + return phase, nil } } diff --git a/cpc/cpc_compressed_state_test.go b/cpc/cpc_compressed_state_test.go index 9aa5f32..6167eff 100644 --- a/cpc/cpc_compressed_state_test.go +++ b/cpc/cpc_compressed_state_test.go @@ -23,6 +23,8 @@ import ( "sort" "testing" + "github.com/stretchr/testify/assert" + "github.com/apache/datasketches-go/internal" ) @@ -84,7 +86,8 @@ func TestWriteReadUnary(t *testing.T) { if nextWordIndex != int(ptrArr[NextWordIdx]) { t.Errorf("Before readUnary: nextWordIndex %d != ptrArr[NextWordIdx] %d", nextWordIndex, ptrArr[NextWordIdx]) } - result := readUnary(compressedWords, ptrArr) + result, err := readUnary(compressedWords, ptrArr) + assert.NoError(t, err) t.Logf("Result: %d, expected: %d", result, i) if result != int64(i) { t.Errorf("Mismatch: got %d, expected %d", result, i) @@ -170,9 +173,10 @@ func TestWriteReadPairs(t *testing.T) { compressedWords := make([]int, MaxWords) // Loop over base bits 0 to 11. for bb := 0; bb <= 11; bb++ { - numWordsWritten := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords) + numWordsWritten, err := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords) + assert.NoError(t, err) t.Logf("numWordsWritten = %d, bb = %d", numWordsWritten, bb) - err := lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten) + err = lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten) if err != nil { t.Errorf("Error in lowLevelUncompressPairs for bb=%d: %v", bb, err) } @@ -390,9 +394,10 @@ func TestWriteReadPairsExtended(t *testing.T) { compressedWords := make([]int, MaxWords) // Loop over base bits 0 to 11. for bb := 0; bb <= 11; bb++ { - numWordsWritten := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords) + numWordsWritten, err := lowLevelCompressPairs(pairArray, numPairs, bb, compressedWords) + assert.NoError(t, err) t.Logf("Base bits: %d, words written: %d", bb, numWordsWritten) - err := lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten) + err = lowLevelUncompressPairs(pairArray2, numPairs, bb, compressedWords, numWordsWritten) if err != nil { t.Errorf("Error in lowLevelUncompressPairs for base bits %d: %v", bb, err) }